Prhub

#24635 Fix stuck when enabling MTP on DSA models

原始 PR 作者 Fridge003 合并时间 2026-05-08 08:06 文件变更 4 提交数 3 评论 7 代码增减 +44 / -18

执行摘要

修复 DSA 模型启用 MTP 时的死锁问题

根据 Issue #24571,深度求索 V3.2 及 GLM-5 等 DSA 模型在启用 MTP 后出现 hang 住的问题,CI 测试也被禁用。该 hang 由 deep_gemm 路径中 fp8_paged_mqa_logits 的 tensor 布局不匹配和 draft extend 模式未覆盖 v2 引起。

此 PR 修复了高优先级 bug,改动集中、逻辑清晰,CI 已全部通过。建议尽快合并并回传到相关发布分支。值得关注的设计决策包括:frozen dataclass 在 CUDA graph replay 中的赋值模式,以及 _to_2d_context_lens 的布局规范方法。

讨论亮点

无 reviewer 评论,作者自行触发 CI 并通过后合并。CI 显示测试 test_dsa_models_mtp.py(8-gpu-h200)、test_deepseek_v32_fp4_mtp_4gpu.py(4-gpu-b200)、test_deepseek_v32_cp_single_node.py(8-gpu-h200-deepep)均通过。

实现拆解

  1. 规范 tensor 布局避免死锁:在 python/sglang/srt/layers/attention/nsa_backend.py_to_2d_context_lens 函数中,将输入 seqlens 强制统一为 (N_total, 1) 形状,消除二义性使 deep_gemm.get_paged_mqa_logits_metadata 不再死锁。当输入为 2D 且列数不为 1 时先扁平化再 reshape,并保证 contiguous。
  2. 扩展 draft extend 条件到 v2:在 init_forward_metadatainit_forward_metadata_capture_cuda_graphinit_forward_metadata_replay_cuda_graph 三处,将 is_draft_extend() 改为 is_draft_extend(include_v2=True),确保 MTP v2 模式也能进入 deep_gemm 的 paged MQA logits 分支,避免因未走该分支导致数据不一致。
  3. 修复冻结 dataclass 赋值错误:在 init_forward_metadata_replay_cuda_graph 中,将 metadata.paged_mqa_schedule_metadata = new_schedule 改为 object.__setattr__(metadata, "paged_mqa_schedule_metadata", new_schedule),因为 NSAMetadata 是 frozen dataclass 直接赋值会抛 FrozenInstanceError,原代码在捕获异常后静默忽略实为隐晦 bug。
  4. 重新启用 CI 测试:移除了三个测试文件中的 disabled="Disabled due to #24268. Should be fixed soon." 行,这些测试覆盖了 DSA 模型 MTP、DeepSeek V3.2 CP 单节点、以及 FP4 量化 MTP 场景,验证修复有效性。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/nsa_backend.py 注意力层 modified 6.95
test/registered/8-gpu-models/test_dsa_models_mtp.py 测试 modified 3.09
test/registered/cp/test_deepseek_v32_cp_single_node.py 测试 modified 2.88
test/registered/quant/test_deepseek_v32_fp4_mtp_4gpu.py 测试 modified 2.88

关键符号

_to_2d_context_lens init_forward_metadata init_forward_metadata_capture_cuda_graph init_forward_metadata_replay_cuda_graph

关键源码片段

python/sglang/srt/layers/attention/nsa_backend.py core-logic

核心修复文件:修改了 `_to_2d_context_lens` 避免死锁,扩展了 draft extend 条件到 v2,并修复了 frozen dataclass 在 CUDA graph replay 中的错误赋值。

def _to_2d_context_lens(seqlens_32: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Always normalize to (N_total, 1) layout, to avoid deadlock at deep_gemm.fp8_paged_mqa_logits
    if seqlens_32.dim() == 2:
        if seqlens_32.size(1) == 1:
            # Already (batch, 1) — done
            return seqlens_32
        # Fall through and re-flatten if the caller already gave us a (bs, next_n)
        # view — we want (N_total, 1) regardless.
        seqlens_32 = seqlens_32.reshape(-1)
    return seqlens_32.contiguous().view(-1, 1)# 调用处(示例来自 init_forward_metadata):
if is_cuda() and (
    forward_batch.forward_mode.is_decode_or_idle()
    or forward_batch.forward_mode.is_target_verify()
    or forward_batch.forward_mode.is_draft_extend(include_v2=True) # 关键修复:从 include_v2=False 改为 True
):
    try:
        import deep_gemm
        # ...
        paged_mqa_schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
            seqlens_32_2d, 64, deep_gemm.get_num_sms()
        )
    except (ImportError, ModuleNotFoundError):
        paged_mqa_schedule_metadata = None

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

核心风险在于 _to_2d_context_lens 的 reshape 行为变更:原先如果输入为 2D 则直接返回,现在会检查列数并可能扁平化再 reshape,这可能改变下游消费该 tensor 的代码预期。但注释说明这是为了统一布局,且下游消费函数 deep_gemm.get_paged_mqa_logits_metadata 预期一个 (batch, 1) 形状,所以应该安全。另外 object.__setattr__ 绕过了 frozen dataclass 的不可变性,可能被其他代码误用,但这是已有模式(capture 中已有直接赋值),replay 中也用相同模式保持一致性。整体风险较低,但涉及 CUDA graph 捕获和重放路径,需防止回退。

影响范围:修复了 DSA 模型(DeepSeek V3.2、GLM-5)在启用 MTP(含 v2)时的 hang 问题;重新激活了三个关键 CI 测试,覆盖多 GPU 场景(TP8 DP8、FP4 4GPU、CP 8GPU)。对使用 EAGLE 推测解码和 NSA attention 的用户是关键修复;不影响没有启用 MTP 的配置。

冻结 dataclass 绕过 核心路径变更 CUDA graph 捕获重放

关联 Issue

#24571 [Bug] MTP causes hang on DSA models after rebasing deep_gemm

完整报告

参与讨论