Prhub

#24870 Support NextN = 2/4 in DSV32

原始 PR 作者 b8zhong 合并时间 2026-06-03 10:06 文件变更 2 提交数 3 评论 5 代码增减 +136 / -92

执行摘要

支持 DSV32 中 NextN=2/4 利用 deep-gemm 原生路径提升 MTP 性能

DeepGEMM 在 PR #316 中去掉了 multicast,使得任意 next_n ≥ 1 可在 SM100 上执行。此 PR 旨在利用该能力,消除原先仅支持 next_n=1 的限制,提升 MTP(Multi-Token Prediction)场景下的解码吞吐。

建议仔细阅读 _build_paged_mqa_schedule_2d_ctx_lens_get_topk_paged 中的条件判断,理解原生路径与回退路径的设计取舍。同时关注后续 revert 或修复 PR 中对测试失败的处理。

讨论亮点

主要的讨论是合并后 ch-wan 发现 main 分支上新增的 DSA 测试失败,怀疑该 PR 引入回归。维护者 Fridge003 触发了额外 CI 测试,但尚未确认根因。此担忧成为后续 revert 的直接原因。

实现拆解

  1. 元数据扩展:在 DSAMetadataDSAIndexerMetadata 中新增 paged_mqa_ctx_lens_2d 字段,用于缓存预计算的上下文长度二维张量,避免每层重复构建。
  2. 调度张量构建:新增 _build_paged_mqa_schedule_2d_ctx_lens 方法,根据 forward mode 和 SM 版本决定是否输出 [B, next_n] 形状(原生路径)或 [B*next_n, 1] 形状(回退路径)。当满足 target_verify、next_n≥2 且为 SM100 时输出原生形状;否则沿用原有 per-token 布局。
  3. 索引器适配:在 _get_topk_paged 中,通过检查 paged_mqa_ctx_lens_2d 的形状决定是否启用原生路径(use_dg_native)。若启用,则直接以 q_fp8.view(B, next_n, H, D) 调用 deep_gemm.fp8_paged_mqa_logits,并相应调整 block_tables[::next_n] 对齐行主序;否则回退到原有 q_fp8.unsqueeze(1) 展开调用。
  4. 导入调整:条件导入 deep_gemm 并新增 is_sm100_supported 工具函数判断。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/dsa_backend.py DSA 后端 modified 7.85
python/sglang/srt/layers/attention/dsa/dsa_indexer.py 稀疏索引器 modified 6.73

关键符号

_build_paged_mqa_schedule_2d_ctx_lens init_forward_metadata _get_topk_paged

关键源码片段

python/sglang/srt/layers/attention/dsa_backend.py dependency-wiring

新增 _build_paged_mqa_schedule_2d_ctx_lens 方法,扩展元数据结构,串联 schedule 预计算与原生路径选取。

# dsa_backend.py 关键变更:新增字段与方法用于预计算二维上下文长度class DSAMetadata:
    # ... 其他字段
    paged_mqa_schedule_metadata: Optional[torch.Tensor] = None
    # 2D context_lens 用于构建 schedule;索引器可复用,避免每层重复广播
    paged_mqa_ctx_lens_2d: Optional[torch.Tensor] = Noneclass DSAIndexerMetadata(BaseIndexerMetadata):
    # ...
    paged_mqa_ctx_lens_2d: Optional[torch.Tensor] = None# 在 DSA 后端中新增的方法
class DeepseekSparseAttnBackend:
    # ...
    def _build_paged_mqa_schedule_2d_ctx_lens(
        self,
        forward_mode: ForwardMode,
        cache_seqlens_int32: torch.Tensor,
        seqlens_expanded: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        # target_verify 且 next_n >= 2 且为 SM100 时,使用原生 [B, next_n] 形状
        next_n = self.speculative_num_draft_tokens
        if (
            forward_mode.is_target_verify()
            and next_n
            and next_n >= 2
            and is_sm100_supported()
        ):
            return cache_seqlens_int32.view(-1, 1).expand(-1, next_n).contiguous()
        if forward_mode.is_target_verify() or forward_mode.is_draft_extend(
            include_v2=True
        ):
            return _to_2d_context_lens(seqlens_expanded, batch_size)
        return _to_2d_context_lens(cache_seqlens_int32, batch_size)# 在 init_forward_metadata 中调用新方法构建 ctx_lens_2d
if is_cuda() and can_build_schedule:
    # 原有 schedule metadata 构建之前
    ctx_lens_2d = self._build_paged_mqa_schedule_2d_ctx_lens(
        forward_batch.forward_mode,
        cache_seqlens_int32,
        seqlens_expanded if hasattr else None,
        batch_size,
    )
    metadata.paged_mqa_ctx_lens_2d = ctx_lens_2d
python/sglang/srt/layers/attention/dsa/dsa_indexer.py core-logic

核心逻辑:根据 ctx_lens_2d 形状选择原生路径或回退路径,优化 MQA logits 计算。

# dsa_indexer.py _get_topk_paged 方法中的关键分支
# 使用预计算的 ctx_2d 形状作为选择依据
ctx_2d = getattr(metadata, "paged_mqa_ctx_lens_2d", None)
use_dg_native = (
    _is_cuda
    and forward_batch.forward_mode.is_target_verify()
    and next_n >= 2
    and ctx_2d is not None
    and ctx_2d.shape == (B, next_n)
)if use_dg_native:
    # 原生路径:直接使用 [B, next_n, H, D] 形状,block_tables 去重
    logits = deep_gemm.fp8_paged_mqa_logits(
        q_fp8[:q_offset].view(B, next_n, q_fp8.shape[1], q_fp8.shape[2]), # 不 unsqueeze
        kv_cache_fp8,
        weights[:q_offset],
        seqlens_32_2d, # 已经来自 ctx_2d
        block_tables[::next_n], # 对应 B 个 block
        schedule_metadata,
        max_seq_len,
        clean_logits=False,
    )
else:
    # 回退路径:保持原有 unsqueeze(1) 展开
    q_fp8 = q_fp8.unsqueeze(1)
    logits = deep_gemm.fp8_paged_mqa_logits(
        q_fp8[:q_offset],
        kv_cache_fp8,
        weights[:q_offset],
        seqlens_32_2d,
        block_tables,
        schedule_metadata,
        max_seq_len,
        clean_logits=False,
    )

评论区精华

DSA 测试回归怀疑 正确性

ch-wan: 'I recently added some attention unittests, and the main branch failed to pass one DSA test after this PR was merged. I'm checking if reverting this PR recovers the test.'

结论:未解决,最终该 PR 被 revert (PR #27138)。 · 已解决

风险与影响

  1. 正确性风险:原生路径仅通过 ctx_2d.shape 作隐式条件,若 target_verify 前后 batch_size 或 next_n 计算偏差可能导致错误分支。
  2. 兼容性风险:SM90 上 next_n=2 仅添加 TODO,未实际验证,可能触发未定义行为。
  3. 回归风险:索引器中 block_tables[::next_n] 要求 stride 兼容,若 page_table 布局不符可能导致偏移错误。
  4. 测试覆盖缺失:本次变更未新增针对性测试,仅依赖已有 CI(后被发现主分支 DSA 测试失败)。

对使用 DeepSeek V3.2 NVFP4 且启用 MTP(speculative EAGLE)的用户,在 SM100 上可获得 ~7% E2E decode TPS 提升;SM90 用户不受影响(回退到原路径)。但可能引入 DSA 测试回归,需要紧急修复或回退。

SM100 原生路径未在 SM90 验证 缺少针对性测试覆盖 已确认导致 DSA 测试回归 原生路径依赖隐式形状判断

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论