Prhub

#26489 [MoE Refactor] Migrate SM90 Cutlass W4A16 to MoeRunner

原始 PR 作者 yuan-luo 合并时间 2026-05-30 17:02 文件变更 5 提交数 1 评论 1 代码增减 +281 / -116

执行摘要

迁移 SM90 cutlass MXFP4 到统一 MoeRunner

PR #24816 以私有 _apply_sm90_cutlass 实现了 SM90 cutlass MXFP4 路径,绕过了统一 MoeRunner。#25525 引入了 register_fused_func 池并迁移了 flashinfer_cutedsl,留下 flashinfer_mxfp4 作为下一个待清理的后端。此外,两个生产调用站点(GPT-OSS 的 mxfp4.py 和 DSv4 的 mxfp4_flashinfer_cutlass_moe.py)有约 80 行几乎相同的 FlashInfer 调用代码且已开始漂移。合并它们可以减少分歧,并为后续改进提供单一修改点。

值得精读,特别是对于理解 SGLang 的 MoE runner 架构演进和 FusedOpPool 设计模式。展示了如何通过注册机制将特定 kernel 路径统一到通用调度框架中。同时关注 gemini-code-assist 提出的空输入风险,建议在后续迭代中考虑添加防御性检查。

讨论亮点

只有一条 review 评论:由 gemini-code-assist[bot]flashinfer_mxfp4.py 第 117 行提出,当输入 x 为空(0 tokens)时,继续进行填充、对称内存分配和调用 FlashInfer kernel 可能导致不必要的开销或 CUDA 崩溃,建议添加提前返回。该建议未在后续讨论中得到认可,PR 最终由 ch-wanrainj-me 直接 approve,评论保持 unresolved。设计上认为空输入在推理中罕见,且 FlashInfer kernel 可能已处理该情况,因此没有修改。

实现拆解

步骤 1:新建 flashinfer_mxfp4 融合函数模块

  • 新增文件 python/sglang/srt/layers/moe/moe_runner/flashinfer_mxfp4.py
  • 定义 dataclass FlashInferMxfp4CutlassMoeQuantInfo(继承 MoeQuantInfo),承载预交错的权重/尺度、可选偏置和 SwiGLU 标量、TP/EP 拓扑以及 GPT-OSS 需要的填充维度。
  • 定义核心融合函数 fused_experts_none_to_flashinfer_mxfp4,通过 @register_fused_func('none', 'flashinfer_mxfp4') 注册到 FusedOpPool。该函数从 dispatch_output 获取 hidden_states 和 topk,处理 bypassed topk,调用 FlashInfer 的 cutlass_fused_moe(use_w4_group_scaling=True),并利用对称内存进行输出分配。
  • 支持 GPT-OSS 的输入填充/输出修剪逻辑和 DSv4 的 SwiGLU 标量可选传递。

步骤 2:改造 Mxfp4MoEMethod(GPT-OSS)

  • 修改 python/sglang/srt/layers/quantization/mxfp4.py
  • create_moe_runner 中为 flashinfer_mxfp4_fi_kernel == 'cutlass_sm90' 添加分支:导入新模块触发注册,然后创建 MoeRunner。之前该分支走 pass(空操作)。
  • 重写 _apply_sm90_cutlass:去掉所有直接 FlashInfer 调用代码,改为构建 FlashInferMxfp4CutlassMoeQuantInfo 并调用 self.runner.run(dispatch_output, quant_info)。同时移除不再需要的导入(如 cutlass_fused_moeActivationType)。

步骤 3:改造 Mxfp4FlashinferCutlassMoEMethod(DSv4)

  • 修改 python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py
  • create_moe_runner 中直接创建 MoeRunner(MoeRunnerBackend.FLASHINFER_MXFP4, moe_runner_config),之前为空。
  • apply 方法从直接调用 cutlass_fused_moe 改为构建 quant_info 并调用 self.runner.run,大幅缩减函数体。同时移除冗余导入。

步骤 4:更新通用 MoeRunner

  • 修改 python/sglang/srt/layers/moe/moe_runner/runner.py
  • MoeRunner.__init__ 中添加 elif runner_backend.is_flashinfer_mxfp4(): self.runner_core = None,表明该 backend 只支持融合路径。

步骤 5:适配测试用例

  • 修改 test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py
  • 新增辅助函数 _build_flashinfer_mxfp4_runner 构造真正的 MoeRunner 实例(绕过 create_moe_runner 对 server args 的依赖)。
  • 在测试方法 test_apply_sm90_cutlass_matches_flashinfer_direct 中扩展 monkeypatch 到新模块 flashinfer_mxfp4(屏蔽对称内存和 TP group),并调整调用签名使用 _MockDispatchOutput
文件 模块 状态 重要度
python/sglang/srt/layers/moe/moe_runner/flashinfer_mxfp4.py MoE 融合函数 added 8.64
python/sglang/srt/layers/quantization/mxfp4.py GPT-OSS 量化 modified 7.36
python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py DSv4 量化 modified 6.72
test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py 测试用例 modified 6.01
python/sglang/srt/layers/moe/moe_runner/runner.py 通用运行器 modified 4.66

关键符号

FlashInferMxfp4CutlassMoeQuantInfo fused_experts_none_to_flashinfer_mxfp4 Mxfp4MoEMethod.create_moe_runner Mxfp4MoEMethod._apply_sm90_cutlass Mxfp4FlashinferCutlassMoEMethod.create_moe_runner Mxfp4FlashinferCutlassMoEMethod.apply _build_flashinfer_mxfp4_runner

关键源码片段

python/sglang/srt/layers/quantization/mxfp4.py core-logic

GPT-OSS 量化方法改造:修改 create_moe_runner 和 _apply_sm90_cutlass,将 kernel 调用委托给 MoeRunner。

def create_moe_runner(self, layer, moe_runner_config):
    self.moe_runner_config = moe_runner_config
    moe_runner_backend = get_moe_runner_backend()
    if moe_runner_backend.is_auto():
        # auto selection logic...
        pass
​
    if moe_runner_backend.is_aiter():
        self.runner = MoeRunner(
            moe_runner_backend,
            replace(moe_runner_config, activation='swiglu')
        )
    elif (moe_runner_backend.is_triton_kernels()
          or moe_runner_backend.is_triton()
          or moe_runner_backend.is_marlin()):
        self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
    elif (moe_runner_backend.is_flashinfer_mxfp4()
          and self._fi_kernel == 'cutlass_sm90'):
        # NEW: register fused func and create MoeRunner
        import sglang.srt.layers.moe.moe_runner.flashinfer_mxfp4 # noqa: F401
        self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
    else:
        # Legacy bypass path (e.g. SM100 trtllm-gen) — not migrated yet
        passdef _apply_sm90_cutlass(self, layer, dispatch_output):
    # Build quant_info and delegate to MoeRunner
    from sglang.srt.layers.moe.moe_runner.flashinfer_mxfp4 import (
        FlashInferMxfp4CutlassMoeQuantInfo,
    )
​
    quant_info = FlashInferMxfp4CutlassMoeQuantInfo(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        w13_weight_scale=layer.w13_weight_scale,
        w2_weight_scale=layer.w2_weight_scale,
        w13_bias=layer.w13_weight_bias,
        w2_bias=layer.w2_weight_bias,
        swiglu_alpha=layer.swiglu_alpha,
        swiglu_beta=layer.swiglu_beta,
        swiglu_limit=layer.swiglu_limit,
        moe_tp_size=layer.moe_tp_size,
        moe_tp_rank=layer.moe_tp_rank,
        moe_ep_size=layer.moe_ep_size,
        moe_ep_rank=layer.moe_ep_rank,
        padded_hidden=self._padded_hidden,
    )
​
    # Delegate to the unified runner
    out = self.runner.run(dispatch_output, quant_info)
    return out
python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py dependency-wiring

DSv4 量化方法改造:apply 方法从直接调用 cutlass_fused_moe 改为使用 MoeRunner。

def create_moe_runner(self, layer, moe_runner_config):
    from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
    from sglang.srt.layers.moe.utils import MoeRunnerBackend
​
    self.moe_runner_config = moe_runner_config
​
    # Set up SwiGLU tensors (clamped activation)
    swiglu_limit = getattr(moe_runner_config, 'swiglu_limit', None)
    if swiglu_limit is not None:
        E = layer.num_local_experts
        device = layer.w13_weight.device
        self._swiglu_alpha_tensor = torch.ones(E, dtype=torch.float32, device=device)
        self._swiglu_beta_tensor = torch.zeros(E, dtype=torch.float32, device=device)
        self._swiglu_limit_tensor = torch.full(
            (E,), swiglu_limit, dtype=torch.float32, device=device
        )
    else:
        self._swiglu_alpha_tensor = None
        self._swiglu_beta_tensor = None
        self._swiglu_limit_tensor = None
​
    # Register fused func and create MoeRunner
    import sglang.srt.layers.moe.moe_runner.flashinfer_mxfp4 # noqa: F401
    self.runner = MoeRunner(MoeRunnerBackend.FLASHINFER_MXFP4, moe_runner_config)def apply(self, layer, dispatch_output):
    # Build quant_info and delegate to MoeRunner
    from sglang.srt.layers.moe.moe_runner.flashinfer_mxfp4 import (
        FlashInferMxfp4CutlassMoeQuantInfo,
    )
    from sglang.srt.layers.moe.topk import TopKOutputChecker
​
    topk_output = dispatch_output.topk_output
    if not TopKOutputChecker.format_is_standard(topk_output):
        raise ValueError(f'Unsupported topk output format: {topk_output.format}')
​
    quant_info = FlashInferMxfp4CutlassMoeQuantInfo(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        w13_weight_scale=layer.w13_weight_scale,
        w2_weight_scale=layer.w2_weight_scale,
        # DSv4 has no bias; leave as None (default)
        w13_bias=None,
        w2_bias=None,
        swiglu_alpha=self._swiglu_alpha_tensor,
        swiglu_beta=self._swiglu_beta_tensor,
        swiglu_limit=self._swiglu_limit_tensor,
        moe_tp_size=layer.moe_tp_size,
        moe_tp_rank=layer.moe_tp_rank,
        moe_ep_size=layer.moe_ep_size,
        moe_ep_rank=layer.moe_ep_rank,
        padded_hidden=None, # DSv4 does not pad
    )
​
    return self.runner.run(dispatch_output, quant_info)
test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py test-coverage

测试适配:新增辅助函数构建真正的 MoeRunner,并更新 monkeypatch 覆盖新模块。

def _build_flashinfer_mxfp4_runner(num_experts, hidden, inter):
    # Construct a real MoeRunner bound to the flashinfer_mxfp4 fused func.
    # Bypasses create_moe_runner (needs live server arg context)
    # and wires with minimal MoeRunnerConfig.
    import sglang.srt.layers.moe.moe_runner.flashinfer_mxfp4 # noqa: F401
    from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
    from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
    from sglang.srt.layers.moe.utils import MoeRunnerBackend
​
    cfg = MoeRunnerConfig(
        num_experts=num_experts,
        num_local_experts=num_experts,
        hidden_size=hidden,
        intermediate_size_per_partition=inter,
        top_k=None,
        activation='silu',
        is_gated=True,
    )
    return MoeRunner(MoeRunnerBackend.FLASHINFER_MXFP4, cfg)
​
​
# In the test function, monkeypatch is extended to the new module:
def test_apply_sm90_cutlass_matches_flashinfer_direct(...):
    import sglang.srt.layers.moe.moe_runner.flashinfer_mxfp4 as fi_mxfp4_mod
    # existing monkeypatch for mxfp4_mod...
    monkeypatch.setattr(fi_mxfp4_mod, 'use_symmetric_memory', lambda *a, **kw: nullcontext())
    monkeypatch.setattr(fi_mxfp4_mod, 'is_allocation_symmetric', lambda: False)
    monkeypatch.setattr(fi_mxfp4_mod, 'get_tp_group', lambda: None)
    # rest of test...

评论区精华

空输入处理 正确性

gemini-code-assist[bot] 评论:当输入 tensor 为空(0 tokens)时,继续进行 padding、对称内存分配和调用 FlashInfer kernel 可能导致不必要的开销或 CUDA crash,建议添加提前返回。

结论:建议未采纳,PR 保持现状。其他 reviewer 直接 approve,未讨论该问题。 · unresolved

风险与影响

  1. 一致性风险:新 fused func 必须与两处原始调用完全等价。GPT-OSS 的 padding 逻辑(input pad + output trim)和 DSv4 的 SwiGLU clamp 是否在统一的 fused func 中正确分支已被测试覆盖,但生产环境中更多模式(如 EP != 1)未在单元测试中覆盖。
  2. 边角情况:空输入(0 tokens)未做特殊处理,虽然不影响当前使用场景,但未来若调度器产生空 batch 可能导致 kernel 异常。
  3. 性能回归:引入 MoeRunner 间接调用增加了函数调用链和 quant_info 建造成本,但极微小。对称内存分配逻辑从 _apply_sm90_cutlass 移到了 fused func,路径一致。
  4. 导入时机:在 create_moe_runner 中 import flashinfer_mxfp4 模块触发注册,可能在意料之外的时间点引入模块加载开销,但注册只发生一次。
  5. 测试覆盖:现有测试验证了数学等价性,但未测试所有拓扑参数(如不同的 moe_tp_size、moe_ep_size)和不同输入形状。

用户:行为无变化,推理结果和性能应保持一致。
系统:代码架构更清晰,两个生产路径(GPT-OSS 和 DSv4)共享同一个融合函数,消除了重复代码和维护负担。后续 LoRA、DeepEP a2a 封装、autotune 缓存共享等改进只需在一个地方添加。
团队:降低了未来改动时的出错概率。需要维护者理解融合函数模块和量化方法之间的约定(quant_info 组合方式)。

核心路径变更 缺少空输入边角处理 测试覆盖有限 两模型一致性风险

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论