Prhub

#24816 Add FlashInfer SM90 cutlass MXFP4 MoE backend (W4A16) for GPT-OSS + DeepSeek-V4

原始 PR 作者 yuan-luo 合并时间 2026-05-14 05:53 文件变更 8 提交数 5 评论 53 代码增减 +1542 / -3

执行摘要

为 GPT-OSS 和 DeepSeek-V4 添加 FlashInfer SM90 MXFP4 MoE 后端

FlashInfer PR #3084 引入了针对 SM90 的混合输入 cutlass MoE 内核,支持 W4A16 的 MXFP4 量化格式。本 PR 将其作为可选的 MoE 后端集成至 SGLang,使得 PD 分离场景下的 prefill 工作节点可以利用其更高的性能(+24-36% at M≥1024),而 decode 节点保留 Marlin 的最佳性能。此设计充分发挥了两者的优势,无需修改默认配置。

本 PR 值得所有关注 MoE 推理性能的工程师仔细阅读。其设计展示了如何将外部高效内核(FlashInfer)集成到现有量化框架中,并保持与 Marlin 的兼容性。关键决策包括:通过 _fi_kernel 区分内核版本、在权重加载时预处理、利用 PD 分离策略发挥各自优势。Review 中关于正确性参数的讨论也具有实践参考价值。建议在 H100/H200 上测试 PD 场景。

讨论亮点
  • samuellees 指出 DeepSeek-V4 可能受 SwiGLU clamp、TP behavior、routed scaling factor、checkpoint layout 等影响,建议增加准确率验证。yuan-luo 修复了 SwiGLU 参数传递(明确传递 alpha=1, beta=0, limit),确认 TP 行为使用真实 tp_size/tp_rank,routed scaling factor 通过 maybe_fuse_routed_scale_and_shared_add 正确处理,并补充了 GSM8K (0.985) 和 GPQA Diamond 准确率测试结果。
  • gemini-code-assist[bot] 建议使用 layer.num_local_experts 和动态设备创建 SwiGLU 张量,以及确保 tensor 连续性。yuan-luo 采纳了关于专家数的建议,并解释了连续性在填充缓冲区中天然满足。
  • Fridge003 建议将单元测试移至 test/registered/unit/layers/quantization 并作为 stage-b 注册,以及添加端到端测试。yuan-luo 照做。
  • 关于 TLLM_LOG_LEVEL 设置,Fridge003 询问是否会引发日志泛滥,yuan-luo 解释这是为了抑制 FlashInfer 0.6.11 的调试日志,无害。

实现拆解

  1. 导入与版本检测:在 mxfp4.py 中导入 FlashInfer 的 SM90 混合输入辅助函数(interleave_moe_weights_for_sm90_mixed_gemminterleave_moe_scales_for_sm90_mixed_gemm),并通过 try/except 提供版本守卫。新增模块级标志 _FI_HAS_SM90_CUTLASS_MXFP4

  2. GPT-OSS 路径扩展mxfp4.py):在 Mxfp4MoEMethod.__init__ 中根据 SM 版本选择内核(_fi_kerneltrtllm_sm100cutlass_sm90)。在 create_weights 中为 SM90 路径将 intermediate_sizehidden_size 填充到 128 的倍数。新增 _process_weights_for_sm90_cutlass_apply_sm90_cutlass 方法,分别处理权重/缩放因子的字节交错和前向调用,并在 process_weights_after_loadingapply 中早期分发。

  3. DeepSeek-V4 后端(新建 mxfp4_flashinfer_cutlass_moe.py):定义 Mxfp4FlashinferCutlassMoEMethod,包含 process_weights_after_loading(权重重排序、转换 E8M0 缩放因子、字节交错)和 apply(调用 cutlass_fused_moe,并利用 maybe_fuse_routed_scale_and_shared_add 融合 routed scaling factor)。

  4. 调度与集成:在 fp8.pyget_quant_method 中添加 SM90/SM100 分发逻辑;在 topk.py 中为 BypassedTopKOutput 添加 to_standard 方法;扩展 mxfp4_flashinfer_trtllm_moe.py 中的 maybe_fuse_routed_scale_and_shared_add 以支持新类。

  5. 测试与基准:新增单元测试(test_mxfp4_sm90_cutlass.py)验证 GPT-OSS 路径的权重处理和前向位精确性;新增基准测试(bench_mxfp4_sm90_kernels.py)对比 Marlin 和 FlashInfer 延迟;在已有 DSv4 端到端测试(test_deepseek_v4_flash_fp4_h200.py)中添加 TestDSV4FlashFP4H200FlashInferCutlass 子测试。

文件 模块 状态 重要度
python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py MoE 后端 added 8.89
python/sglang/srt/layers/quantization/mxfp4.py 量化层 modified 8.69
test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py 单元测试 added 7.48
python/sglang/test/bench_mxfp4_sm90_kernels.py 基准测试 added 7.49
test/registered/dsv4/test_deepseek_v4_flash_fp4_h200.py 集成测试 modified 6.4
python/sglang/srt/layers/moe/topk.py 路由层 modified 5.4
python/sglang/srt/layers/quantization/fp8.py 量化配置 modified 5.48
python/sglang/srt/layers/quantization/mxfp4_flashinfer_trtllm_moe.py 量化层 modified 4.94

关键符号

Mxfp4FlashinferCutlassMoEMethod.__init__ Mxfp4FlashinferCutlassMoEMethod.process_weights_after_loading Mxfp4FlashinferCutlassMoEMethod.apply Mxfp4MoEMethod._process_weights_for_sm90_cutlass Mxfp4MoEMethod._apply_sm90_cutlass Mxfp4MoEMethod.__init__. 内核分发 _flashinfer_has_sm90_cutlass_mxfp4

关键源码片段

python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py dependency-wiring

新增核心类 Mxfp4FlashinferCutlassMoEMethod,实现 DeepSeek-V4 的 SM90 cutlass MXFP4 权重处理与推理。

class Mxfp4FlashinferCutlassMoEMethod:
    """DeepSeek-V4 W4A16 MXFP4 MoE 后端,基于 FlashInfer 的 SM90 mixed-input cutlass
    分组 GEMM。融合内核在一个调用中完成 GEMM1 + 受限 SwiGLU + GEMM2。
    权重和缩放因子在加载时进行一次性预处理。"""
​
    def __init__(self, fp8_method, prefix: str):
        # 检查 FlashInfer 版本是否包含 SM90 混合输入辅助函数
        if not _FI_HAS_SM90_CUTLASS_MXFP4:
            raise RuntimeError(
                "Mxfp4FlashinferCutlassMoEMethod 需要 FlashInfer >= 0.6.11 "
                "(PR #3084 SM90 mixed-input helpers)。"
            )
        self._fp8 = fp8_method
        self.prefix = prefix
        # SwiGLU 参数张量,供 fused 内核使用
        self._swiglu_alpha_tensor: torch.Tensor | None = None
        self._swiglu_beta_tensor: torch.Tensor | None = None
        self._swiglu_limit_tensor: torch.Tensor | None = None
python/sglang/srt/layers/quantization/mxfp4.py dependency-wiring

修改 Mxfp4MoEMethod 类,新增 SM90 cutlass 路径的权重处理函数和前向函数,支持 GPT-OSS 的 FlashInfer 新后端。

# 在 Mxfp4MoEMethod.__init__ 中添加的内核分发逻辑:
# 根据 GPU 架构选择 FlashInfer 入口点
self._fi_kernel: Optional[str] = None
if self.use_flashinfer:
    if is_sm100_supported():
        self._fi_kernel = "trtllm_sm100"
    elif is_sm90_supported():
        if not _FI_HAS_SM90_CUTLASS_MXFP4:
            raise RuntimeError(
                "moe_runner_backend=flashinfer_mxfp4 on SM90 requires the "
                "interleave_moe_{weights,scales}_for_sm90_mixed_gemm helpers "
                "from FlashInfer PR #3084 (>= 0.6.11). Upgrade flashinfer-python "
                "or pick a different backend (e.g. marlin / triton_kernel)."
            )
        self._fi_kernel = "cutlass_sm90"
    else:
        raise NotImplementedError(
            "moe_runner_backend=flashinfer_mxfp4 requires SM90 or SM100."
        )

评论区精华

DeepSeek-V4 正确性验证 正确性

samuellees 指出 SwiGLU clamp、TP behavior、routed scaling factor、checkpoint layout 可能影响 DSv4 准确率,建议运行 GSM8k 等基准。

结论:yuan-luo 修复了 SwiGLU 参数传递,确认 TP 行为和 routed scaling factor 正确,并补充了 GSM8K (0.985) 和 GPQA Diamond 结果,验证了数值正确性。 · 已解决

SwiGLU 参数修复 正确性

GPT-OSS 路径已使用显式 alpha=1.702, beta=1.0, limit=7.0,但 DSv4 路径最初传递 alpha=None, beta=None,导致混合分支行为。yuan-luo 修复为 alpha=ones, beta=zeros。

结论:修复完成,DSv4 路径现在使用与 trtllm 路径一致的参数。 · 已解决

EP 支持与设备分配 正确性

gemini-code-assist[bot] 建议创建 SwiGLU 张量时使用 layer.num_local_experts 代替 self.num_experts,以及使用权重张量的设备而非硬编码 'cuda'。

结论:yuan-luo 采纳建议,使用本地专家数和动态设备。 · 已解决

Tensor 连续性 style

gemini-code-assist[bot] 建议确保传给 FlashInfer 的权重连续。yuan-luo 解释填充缓冲区连续,无需额外操作。

结论:确认在填充缓冲区中天然连续,无需修改。 · 已解决

测试位置与注册 测试

Fridge003 建议将单元测试移至 test/registered/unit/layers/quantization 并注册为 stage-b。

结论:yuan-luo 移动并注册。 · 已解决

风险与影响

  • 依赖风险:新后端需要 FlashInfer >= 0.6.11(PR #3084)。旧版本会抛出 RuntimeError,但可能阻止回退到 Marlin,用户需升级。
  • 性能折中:decode 阶段(M ≤ 64)FlashInfer 比 Marlin 慢 12-15%,建议仅对 prefill 节点启用,否则可能引起整体性能下降。
  • 形状约束:GPT-OSS 路径中的 hidden_size 和 intermediate_size 需为 128 的倍数,否则会报错填充;DeepSeek-V4 标准配置(7168, 2048)天然满足。
  • 正确性风险:SwiGLU 参数在不同模型中可能不同,新变体需仔细校对。routed scaling factor 融合逻辑已通过测试,但需持续验证。
  • EP 兼容性:Review 中已修复使用本地专家数和动态设备,但仍需确保未来 EP 配置下的正确性。
  • 用户影响:新增 --moe-runner-backend flashinfer_mxfp4 选项,在 H100/H200 上 prefill 性能提升显著,推荐用于 PD 分离的 prefill 节点。不影响默认 Marlin 路径。
  • 系统影响:新增 1 个源文件、1 个单元测试、1 个基准测试;修改 5 个现有源文件和 1 个测试文件。代码约 1500+ 行,但逻辑扩展条件分支,测试覆盖完整。
  • 团队影响:提供了集成外部 MoE 内核的参考模式,未来可复用相同架构支持更多 FlashInfer 内核。
依赖 FlashInfer >= 0.6.11 形状约束 128 对齐 decode 性能回退风险 SwiGLU 参数因模型而异 EP 配置需验证

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论