Prhub

#26494 Remove DeepGEMM for indexer GEMM in piecewise NSA path

原始 PR 作者 b8zhong 合并时间 2026-05-28 15:15 文件变更 1 提交数 1 评论 3 代码增减 +1 / -6

执行摘要

移除 NSA 分段路径中冗余的 DeepGEMM 依赖

PR body 指出 DeepGEMM 调用是在不同 PR 合并时被意外重新引入的 (由 PR#23351 引入),且会增加预热时间。作者希望移除它来加速 warmup 并清理依赖。

该 PR 改动简洁且正面,值得合并。代码风格清晰,注释充分。建议阅读者关注 torch.mmout_dtype 用法,这是 PyTorch 2.10 的新特性。

讨论亮点

gemini-code-assist[bot] 指出 torch.mm 不支持 out_dtype 参数,建议改用 .to(torch.float32)。但作者 @b8zhong 回应称 PyTorch 2.10 已支持该参数,确认正确性。MR 随后被 reviewer Fridge003 批准合并。

实现拆解

  1. python/sglang/srt/layers/attention/dsa/dsa_indexer.pylogits_head_gate_pcg 函数中,将 deep_gemm_wrapper.gemm_nt_bf16bf16f32 调用替换为 torch.mm(x, weight.t(), out_dtype=torch.float32)
  2. 移除了对应的 DeepGEMM 导入语句 from sglang.srt.layers.deep_gemm_wrapper import entrypoint as deep_gemm_wrapper
  3. 精简了 GEMM 实现,不再需要显式创建空 tensor 再写入,而是直接利用 PyTorch 2.10+ 支持的 out_dtype 参数完成矩阵乘法并输出 float32 结果。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/dsa/dsa_indexer.py 注意力层 modified 5.4

关键符号

logits_head_gate_pcg

关键源码片段

python/sglang/srt/layers/attention/dsa/dsa_indexer.py dependency-wiring

核心变更文件:替换 indexer GEMM 的实现,移除 DeepGEMM 依赖,减少 warmup 时间。

# 文件 : python/sglang/srt/layers/attention/dsa/dsa_indexer.py
# 此函数实现 indexer 的 GEMM 逻辑,在 piecewise NSA 路径中被调用
# 变更前使用了 DeepGEMM wrapper,增加了启动预热时间
# 变更后改用 PyTorch 原生 torch.mm ( 需要 PyTorch >= 2.10 支持 out_dtype)@register_custom_op(fake_impl=_logits_head_gate_pcg_fake_impl)
def logits_head_gate_pcg(
    x: torch.Tensor,
    weight: torch.Tensor,
    n_heads_inv_sqrt: float,
    softmax_scale: float,
    q_scale: torch.Tensor,
) -> torch.Tensor:
    # 直接使用 torch.mm 进行矩阵乘法,并指定输出为 float32
    # 避免了之前 deep_gemm_wrapper 的预热开销
    out = torch.mm(x, weight.t(), out_dtype=torch.float32)
    weights = out * n_heads_inv_sqrt
    weights = weights.unsqueeze(-1) * q_scale * softmax_scale
    return weights

评论区精华

torch.mm out_dtype 参数兼容性 正确性

gemini-code-assist[bot] 指出标准 torch.mm 不支持 out_dtype 参数,会引发 TypeError,建议改为 .to(torch.float32)。@b8zhong 回复称该参数在 PyTorch 2.10 中已支持。

结论:开发团队确认 torch.mm 的 out_dtype 在 PyTorch 2.10 可用,代码正确。 · 已解决

风险与影响

风险极低:改动仅涉及 indexer GEMM 的单次矩阵乘法路径,且已有独立的精度验证(PR#23856)。如果运行环境中 PyTorch 版本低于 2.10,out_dtype 参数会报错,但 sglang 通常要求较新的 PyTorch 版本,因此不太可能发生。建议在 CI 中确认版本兼容性。

直接影响:减少分段 NSA 路径的 warmup 时间,消除 DeepGEMM 的依赖加载开销。对推理精度无影响。对用户的唯一影响是减少了框架启动时间。

依赖新 PyTorch 特性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论