Prhub

#23285 [Flashinfer] Integrate flashinfer router gemm for sm103

原始 PR 作者 Fridge003 合并时间 2026-04-28 11:37 文件变更 1 提交数 1 评论 3 代码增减 +1 / -1

执行摘要

Flashinfer router gemm 支持 sm103

PR 动机是在 Blackwell Ultra(sm103)上利用 flashinfer 的优化 router gemm 提升路由 GEMM 性能。依赖 flashinfer >=0.6.8 及 flashinfer PR#2991。

值得合并,改动小而明确。建议关注后续 flashinfer 版本更新,确保兼容性。

讨论亮点

无 review 评论。

实现拆解

  1. 修改路由 GEMM 的条件判断:在 python/sglang/srt/models/deepseek_v2.pyDeepseekV2Gate.forward 方法中,将 _device_sm == 100 改为 _device_sm in [100, 103],使得 sm103 设备也能进入 flashinfer router gemm 分支(调用 flashinfer_dsv3_router_gemm)。
  2. 输出精度保证:该路径输出为 float32 类型,与原有 sm100 行为一致。
  3. 无需额外配置:不涉及配置文件、命令行参数或 API 变更。
文件 模块 状态 重要度
python/sglang/srt/models/deepseek_v2.py 模型层 modified 5.1

关键符号

DeepseekV2Gate.forward

关键源码片段

python/sglang/srt/models/deepseek_v2.py core-logic

核心变更文件,修改路由 GEMM 条件以支持 sm103。

# python/sglang/srt/models/deepseek_v2.py
# 路由 GEMM 前向函数中,选择算子时的条件判断
if (
    _is_cuda
    and hidden_states.shape[0] <= 16
    and hidden_states.shape[1] == 7168
    and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
    and _device_sm >= 90
):
    # 关键变更:原本只支持 sm100,现在新增 sm103(Blackwell Ultra)
    if _device_sm in [100, 103] and self.weight.shape[0] == 256:
        # router gemm output float32
        logits = torch.empty(
            hidden_states.shape[0],
            self.weight.shape[0],
            device=hidden_states.device,
            dtype=torch.float32,
        )
        flashinfer_dsv3_router_gemm(logits, hidden_states, self.weight)
    else:
        logits = dsv3_router_gemm(
            hidden_states, self.weight, out_dtype=torch.float32
        )

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险极低。变更仅一行条件判断,且测试覆盖了多种 batch size 的正确性,精度与原有实现一致。未覆盖 sm103 下的性能退化场景,但 flashinfer 侧已提供兼容性保证。

影响范围小:仅影响 sm103 设备上 DeepSeek-V3/V2 路由 GEMM 的算子选择。用户无需手动配置,flashinfer 版本满足要求即可自动启用。性能提升依赖于 flashinfer 算子的优化效果。

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论