Prhub

#27138 Revert "Support NextN = 2/4 in DSV32"

原始 PR 作者 ch-wan 合并时间 2026-06-03 16:29 文件变更 2 提交数 1 评论 3 代码增减 +92 / -136

执行摘要

回退 DSV32 中 NextN=2/4 支持,恢复 next_n=1

原 PR #24870 在 DeepGEMM paged MQA logits 中增加了对 next_n≥2 的原生支持(SM100 专用路径),以提升 speculative decoding 性能。但 CI 测试(pr-test 和 pr-test-extra)均显示失败(红色叉),且自动化代码审查指出 deep_gemm 导入异常可能引发混淆的 AttributeError。回退以恢复稳定版本,待后续修复后重新引入。

值得精读,尤其是对 speculative decoding 和 DeepGEMM 调度感兴趣的同学。回退体现了在引入复杂性能优化时对稳定性的权衡,同时自动化 code review 提示了 import 错误处理的细节可作参考。建议关注后续关联 PR 中如何更稳健地重新实现 NextN 支持。

讨论亮点
  • 自动化代码审查(gemini-code-assist[bot]):在 dsa_indexer.py 第 612 行附近提到,当 _is_cuda 为 True 但 deep_gemm 未安装时(因为 dsa_backend.py 捕获了 ImportError),访问 deep_gemm.get_paged_mqa_logits_metadata 会抛出混淆的 AttributeError,建议显式检查 deep_gemm 是否为 ImportError 实例。该建议未在本次回退中处理,属于遗留风险。
  • 重新运行测试:作者 ch-wan 发出 /rerun-test registered/attention/unittests/dsa/test_dsa.py,结果在 4-gpu-b200 上失败,在 1-gpu-h100 上通过,可能进一步促成了回退。

实现拆解

  1. dsa_backend.py:删除 _build_paged_mqa_schedule_2d_ctx_lens 函数;从 DSAMetadataDSAIndexerMetadata 中移除 paged_mqa_ctx_lens_2d 字段;移除对 is_sm100_supported 的导入和 deep_gemm 的早期导入;在 init_forward_metadata 中移除 native 2D ctx lens 计算,统一使用原有的 per-token 或 expanded 布局。
  2. dsa_indexer.py:简化 _get_topk_paged 方法,移除 native next_n 路径(use_dg_native 分支),将 query 始终 unsqueeze(1) 确保 next_n=1;移除 paged_mqa_ctx_lens_2d 的获取和使用,直接根据 seqlens_32 维度决定 2D layout。
  3. 无测试配套变更:本次仅修改源码文件,未调整测试用例。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/dsa_backend.py DSA 后端 modified 7.85
python/sglang/srt/layers/attention/dsa/dsa_indexer.py DSA 索引器 modified 6.73

关键符号

_build_paged_mqa_schedule_2d_ctx_lens init_forward_metadata _get_topk_paged

关键源码片段

python/sglang/srt/layers/attention/dsa_backend.py core-logic

核心变更文件,删除了 NextN 调度的核心函数和字段,调整了 import 和 init_forward_metadata 逻辑。

# python/sglang/srt/layers/attention/dsa_backend.py
# 回退后:DSAMetadata 中移除了 paged_mqa_ctx_lens_2d 字段
class DSAMetadata:
    # ... 其他字段 ...
    paged_mqa_schedule_metadata: Optional[torch.Tensor] = None
    # paged_mqa_ctx_lens_2d 被删除,不再维护 2D context length
​
    # init_forward_metadata 中不再调用 _build_paged_mqa_schedule_2d_ctx_lens
    # 直接使用原有的 per-token 布局
    if is_cuda() and (
        forward_batch.forward_mode.is_decode_or_idle()
        or forward_batch.forward_mode.is_target_verify()
        or forward_batch.forward_mode.is_draft_extend(include_v2=True)
    ):
        paged_mqa_schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
            # 使用 per-token 的 2D context lens
            _to_2d_context_lens(cache_seqlens_int32, batch_size),
            blocksize, self.sm_count
        )
python/sglang/srt/layers/attention/dsa/dsa_indexer.py core-logic

indexer 中移除了 native next_n 路径,简化 query 处理逻辑,总是 unsqueeze next_n 为 1。

# python/sglang/srt/layers/attention/dsa/dsa_indexer.py
# _get_topk_paged 方法中删除了 native next_n 分支
# 回退后:
assert len(q_fp8.shape) == 3
q_fp8 = q_fp8.unsqueeze(1) # 强制 next_n = 1,适配 DeepGEMM API
# ...
q_offset = sum(metadata.get_dsa_extend_len_cpu())
if _is_hip:
    # AMD 路径保持不变
    # ...
else:
    logits = deep_gemm.fp8_paged_mqa_logits(
        q_fp8[:q_offset],
        kv_cache_fp8,
        weights[:q_offset],
        seqlens_32_2d,
        block_tables,
        schedule_metadata,
        max_seq_len,
        clean_logits=False,
    )
# 不再有 use_dg_native 分支,始终使用 expanded(unsqueezed)格式

评论区精华

ImportError 处理风险 安全

gemini-code-assist[bot] 提出,当 _is_cuda 为 True 但 deep_gemm 未安装时,dsa_indexer.py 中访问 deep_gemm.get_paged_mqa_logits_metadata 会抛出 AttributeError,因为 deep_gemm 可能是一个 ImportError 对象(由于 dsa_backend.py 中的 import 被异常捕获)。建议添加显式检查。

结论:未在本次 PR 中处理,属于遗留风险。回退本身未改变 deep_gemm 的导入方式,但移除了早期 import 语句(原 PR 中 import deep_gemm 在 if is_cuda() 块内,回退后删除该块),可能导致该路径更易触发。 · unresolved

CI 测试失败 question

作者 ch-wan 发起 /rerun-test 针对 dsa/test_dsa.py,结果在 4-gpu-b200 上失败,1-gpu-h100 上通过,暗示 blackwell 多 GPU 配置下存在兼容性问题。

结论:测试失败直接触发了本次回退。 · 已解决

风险与影响

  • 回退本身风险低:回退到 next_n=1 的稳定路径,消除了原 PR 在非 SM100 芯片上的兼容性风险。
  • 性能回退:丢失了 SM100 上 next_n≥2 的性能优化(原 PR 展示 GPQA 准确率提升约 1-2%),影响 speculative decoding 吞吐。
  • 遗留风险dsa_indexer.py 中未处理 deep_gemm 未安装时的优雅降级,若 _is_cuda 为 True 而 deep_gemm 缺失,仍会因早期导入被移除而引发 AttributeError(需待用户明确安装或补充检查)。
  • 用户影响:使用 SM100(Blackwell)芯片进行 DeepSeek 系列模型推理且依赖 NextN>1 speculative decoding 的用户,将无法获得原 PR 带来的加速效果;其他用户无影响。
  • 系统影响:消除因 native 2D ctx lens 导致的不确定性,系统稳定性恢复。
  • 团队影响:需要重新评估 NextN 支持的实现方案,特别是跨平台兼容性和异常处理。
部分性能回退 ImportError 隐患未修复 核心注意力路径变更 缺少测试覆盖

关联 Issue

#24870 Support NextN = 2/4 in DSV32

完整报告

参与讨论