执行摘要
- 一句话:扩展 OAI Triton MoE ROCm 支持至 RDNA4
- 推荐动作:值得精读:本 PR 展示了在跨平台代码中处理设备功能检测的简洁方法,避免了 capability 元组因供应商实现不同而产生的别名问题。关注点:如何通过字符串匹配(
on_gfx1x)避免硬编码 capability 数字,以及如何通过集中化辅助函数消除重复。
功能与动机
根据 PR body,gfx12 (RDNA4) 映射到 capability (12,0),先前被 (9,0) <= cap < (11,0) 排除。Triton 3.5+ 通过 DecomposeScaledBlocked 支持 tl.dot_scaled,因此标准 MXFP4 MoE 路径无需定制内核即可工作。本 PR 替换了 #34632 尝试添加自定义反量化内核的思路,因为实际测试发现并不需要。
实现拆解
1. 提取共享设备检测逻辑
在 vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py 中新增私有函数 _triton_kernel_moe_supports_current_device(),将原本内联在 _supports_current_device 方法中的设备检查逻辑抽取到单一函数中。
2. CUDA 路径保持不变
CUDA 平台下,函数仍使用 (9,0) <= cap < (11,0) 窗口,覆盖 Hopper (SM90) 和 Blackwell (SM100),排除 SM120 及更早架构。
3. ROCm 路径使用平台辅助函数
ROCm 平台下,函数委托给 vllm.platforms.rocm 中已有的 on_gfx9()(gfx90a/gfx942/gfx950,排除 gfx906/gfx908)和 on_gfx1x()(gfx11xx/gfx12xx,排除 gfx10xx),精确匹配已知支持 Triton MoE 的 AMD GPU。
4. 更新专家类的 _supports_current_device
BaseOAITritonExperts 和 OAITritonMxfp4ExpertsMonolithic 的静态方法均改为调用 _triton_kernel_moe_supports_current_device() 并合取 has_triton_kernels(),消除重复代码。
5. 移除 oracle 模块中的重复检查
oracle/mxfp4.py 中的后端选择器改为从 experts 模块导入同一辅助函数,避免了设备检查逻辑的副本(最终提交中该文件不再修改,仅保留一个定义)。
测试验证
作者在 AMD Radeon AI PRO R9700 (gfx1201) 上加载 openai/gpt-oss-20b 模型并通过 curl 验证了推理输出正确。
关键文件:
vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py(模块 MoE 专家;类别 source;类型 data-contract;符号 _triton_kernel_moe_supports_current_device, BaseOAITritonExperts._supports_current_device, OAITritonMxfp4ExpertsMonolithic._supports_current_device): 唯一的变更文件,包含新增的设备检测辅助函数和两个专家类的修改,是整个 PR 的核心。
关键符号:_triton_kernel_moe_supports_current_device, BaseOAITritonExperts._supports_current_device, OAITritonMxfp4ExpertsMonolithic._supports_current_device
关键源码片段
vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py
唯一的变更文件,包含新增的设备检测辅助函数和两个专家类的修改,是整个 PR 的核心。
# 平台感知的设备支持检查函数,集中管理 OAI Triton MoE 的设备白名单
# 在 CUDA 上保持原有 (9,0)<=cap<(11,0) 窗口,在 ROCm 上使用
# 架构字符串辅助函数精确匹配,避免 capability 别名问题
def _triton_kernel_moe_supports_current_device() -> bool:
p = current_platform
if p.is_cuda():
cap = p.get_device_capability()
# CUDA 保持原有范围:Hopper SM90 / Blackwell SM100
return cap is not None and (9, 0) <= (cap.major, cap.minor) < (11, 0)
if p.is_rocm():
from vllm.platforms.rocm import on_gfx1x, on_gfx9
# gfx9 家族:MI200 (gfx90a), MI3xx (gfx942/gfx950),已排除 gfx906/gfx908
# gfx1x 家族:RDNA3 (gfx11xx) 和 RDNA4 (gfx12xx),排除 gfx10xx
return on_gfx9() or on_gfx1x()
return False
class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_current_device() -> bool:
# 委托给共享函数并确保 Triton 内核可用
return _triton_kernel_moe_supports_current_device() and has_triton_kernels()
class OAITritonMxfp4ExpertsMonolithic(BaseOAITritonExperts):
# 继承基类的 _supports_current_device,无需重复定义
pass
评论区精华
hongxiayang (要求更具体白名单): “gfx10xx (gfx1030) should not be included. Don't use >=9 and <13, which is not accurate.” → 最终改用 on_gfx9() / on_gfx1x(),精确排除不受支持的架构。
BowenBao (消除重复): “could we keep one instance of this method and import it in the other file?” → 最终将定义集中在 experts 模块,从 oracle 模块导入,减少重复。
tjtanaa (保持 CUDA 条件): “Given that this PR is addressing ROCm related changes, I would prefer we keep the CUDA condition be the same as before.” → 最终 CUDA 路径恢复为 < (11,0),未拓宽。
gemini-code-assist[bot] (自动审查早期建议): 建议使用更具体的 (9,0)<=cap<(11,0) or cap.major==12 来避免误支持 gfx11,但后续人工审查认为 gfx11 同样受支持,最终通过平台辅助函数解决了精确性问题。
- 设备白名单精确性 (design): 改用
on_gfx9() / on_gfx1x() 精确匹配已知支持的 AMD GPU,排除 gfx10xx。
- 消除重复代码 (design): 将定义集中在 experts 模块的
_triton_kernel_moe_supports_current_device 中,oracle 模块通过从 experts 导入来使用。
- 保持 CUDA 条件不变 (design): 最终 CUDA 路径恢复为原有窗口,未引入任何 CUDA 侧改变。
- 函数命名清晰性 (style): 接受建议,最终函数名为
_triton_kernel_moe_supports_current_device(前导下划线表示内部使用)。
风险与影响
- 风险:
- 回归风险低:CUDA 路径未变;ROCm 路径依赖经过验证的平台辅助函数(
on_gfx9/on_gfx1x),排除了已知不支持的架构(gfx906/gfx908/gfx10xx)。
- 未覆盖的 ROCm 设备:若未来添加 gfx13xx 等新架构,
on_gfx1x() 通过子串匹配可能需要扩展;但当前范围明确。
- 无性能影响:仅修改设备检测逻辑,不涉及计算路径。
- 缺少自动化 CI 测试:虽在实机测试过,但无新增 CI 测试用例,后续重构可能遗漏此类平台检查。
- 影响:
- 用户影响:使用 AMD RDNA4(gfx12)的用户现在可以运行 MXFP4 MoE 模型(如
openai/gpt-oss-20b),之前无法加载。gfx11 (RDNA3) 用户不受影响(原有支持仍在)。
- 系统影响:仅影响 MoE 后端选择,不更改其他硬件路径。无性能退化预期。
- 团队影响:确立并文档化了“使用
on_gfx9/on_gfx1x 而非 capability 元组”的实践,方便后续添加新 ROCm 架构时参考。
- 风险标记:回归风险低, 缺少测试覆盖, ROCm 特有配置
关联脉络
- PR #34632 尝试为 RDNA4 添加自定义反量化内核: 本 PR 直接替换了 #34632 的方案,因为测试发现标准 Triton 路径在 RDNA4 上即可工作,无需自定义内核。
- PR #34032 早期添加 gfx11/gfx12 Triton MoE 支持: 该 PR 曾为 gfx11/gfx12 添加过支持,后来在重构中丢失。本 PR 恢复了该支持,并以更精确的方式实现。
参与讨论