# PR #26494 完整报告

- 仓库：`sgl-project/sglang`
- 标题：Remove DeepGEMM for indexer GEMM in piecewise NSA path
- 合并时间：2026-05-28 15:15
- 原文链接：http://prhub.com.cn/sgl-project/sglang/pull/26494

---

# 执行摘要

- 一句话：移除 NSA 分段路径中冗余的 DeepGEMM 依赖
- 推荐动作：该 PR 改动简洁且正面，值得合并。代码风格清晰，注释充分。建议阅读者关注 `torch.mm` 的 `out_dtype` 用法，这是 PyTorch 2.10 的新特性。

# 功能与动机

PR body 指出 DeepGEMM 调用是在不同 PR 合并时被意外重新引入的 ( 由 PR#23351 引入 )，且会增加预热时间。作者希望移除它来加速 warmup 并清理依赖。

# 实现拆解

1. 在 `python/sglang/srt/layers/attention/dsa/dsa_indexer.py` 的 `logits_head_gate_pcg` 函数中，将 `deep_gemm_wrapper.gemm_nt_bf16bf16f32` 调用替换为 `torch.mm(x, weight.t(), out_dtype=torch.float32)`。
2. 移除了对应的 DeepGEMM 导入语句 `from sglang.srt.layers.deep_gemm_wrapper import entrypoint as deep_gemm_wrapper`。
3. 精简了 GEMM 实现，不再需要显式创建空 tensor 再写入，而是直接利用 PyTorch 2.10+ 支持的 `out_dtype` 参数完成矩阵乘法并输出 float32 结果。

关键文件：
- `python/sglang/srt/layers/attention/dsa/dsa_indexer.py`（模块 注意力层；类别 source；类型 dependency-wiring；符号 logits_head_gate_pcg）: 核心变更文件：替换 indexer GEMM 的实现，移除 DeepGEMM 依赖，减少 warmup 时间。

关键符号：logits_head_gate_pcg

## 关键源码片段

### `python/sglang/srt/layers/attention/dsa/dsa_indexer.py`

核心变更文件：替换 indexer GEMM 的实现，移除 DeepGEMM 依赖，减少 warmup 时间。

```python
# 文件 : python/sglang/srt/layers/attention/dsa/dsa_indexer.py
# 此函数实现 indexer 的 GEMM 逻辑，在 piecewise NSA 路径中被调用
# 变更前使用了 DeepGEMM wrapper，增加了启动预热时间
# 变更后改用 PyTorch 原生 torch.mm ( 需要 PyTorch >= 2.10 支持 out_dtype)

@register_custom_op(fake_impl=_logits_head_gate_pcg_fake_impl)
def logits_head_gate_pcg(
    x: torch.Tensor,
    weight: torch.Tensor,
    n_heads_inv_sqrt: float,
    softmax_scale: float,
    q_scale: torch.Tensor,
) -> torch.Tensor:
    # 直接使用 torch.mm 进行矩阵乘法，并指定输出为 float32
    # 避免了之前 deep_gemm_wrapper 的预热开销
    out = torch.mm(x, weight.t(), out_dtype=torch.float32)
    weights = out * n_heads_inv_sqrt
    weights = weights.unsqueeze(-1) * q_scale * softmax_scale
    return weights

```

# 评论区精华

gemini-code-assist[bot] 指出 `torch.mm` 不支持 `out_dtype` 参数，建议改用 `.to(torch.float32)`。但作者 @b8zhong 回应称 PyTorch 2.10 已支持该参数，确认正确性。MR 随后被 reviewer Fridge003 批准合并。

- torch.mm out_dtype 参数兼容性 (correctness): 开发团队确认 torch.mm 的 out_dtype 在 PyTorch 2.10 可用，代码正确。

# 风险与影响

- 风险：风险极低：改动仅涉及 indexer GEMM 的单次矩阵乘法路径，且已有独立的精度验证（PR#23856）。如果运行环境中 PyTorch 版本低于 2.10，`out_dtype` 参数会报错，但 sglang 通常要求较新的 PyTorch 版本，因此不太可能发生。建议在 CI 中确认版本兼容性。
- 影响：直接影响：减少分段 NSA 路径的 warmup 时间，消除 DeepGEMM 的依赖加载开销。对推理精度无影响。对用户的唯一影响是减少了框架启动时间。
- 风险标记：依赖新 PyTorch 特性

# 关联脉络

- PR #23351 ( 原始引入 DeepGEMM 的 PR，未提供具体标题 ): 该 PR 意外将 DeepGEMM 调用重新引入 piecewise 路径，本 PR 旨在移除它。
- PR #23856 ( 精度验证 PR，未提供具体标题 ): 本 PR 所依赖的精度测试在此 PR 中已经完成验证。