Prhub

#25002 [spec_v2] Enable trtllm_mha draft-extend CUDA graph with v2 semantics

原始 PR 作者 YAMY1234 合并时间 2026-06-05 08:50 文件变更 4 提交数 12 评论 7 代码增减 +68 / -16

执行摘要

为 trtllm_mha 启用 spec_v2 draft-extend CUDA graph

V2 draft-extend 路径此前无法启用 trtllm_mha 的 CUDA graph,因为在白名单中缺失 TRTLLMHAAttnBackend。且旧 replay 逻辑会计算完整的 softmax+topk/reduce,但这些结果在 V2 worker 中不被使用(worker 会重新对选中行做 softmax+fast_topk),造成无效计算。

值得精读。本 PR 展示了在复杂推测解码路径中启用 CUDA graph 的完整思路:白名单控制、metadata 语法适配、合理精简 graph 内部计算量以避免浪费,以及对应的测试合约更新。对于理解 speculative v2、TRTLLM backend 以及 CUDA graph 的正确使用很有参考价值。

讨论亮点

Review 中 merrymercy 建议将 eagle_worker_v2.py 中孤立的 or isinstance(..., TRTLLMHAAttnBackend) 改为用元组统一判断,类似已有的其他 backend 写法。作者采纳并修改。

实现拆解

  1. 白名单扩展:在 eagle_worker_v2.pysupports_cuda_draft_extend_graph 条件中添加 TRTLLMHAAttnBackend,并按照 review 建议将孤立 isinstance 合并为元组判断。
  2. Metadata 分支调整:在 trtllm_mha_backend.py 中,将 .is_draft_extend() 调用改为 .is_draft_extend(include_v2=True),让 V2 也进入 draft-extend 的 metadata 构建分支。并在 _apply_cuda_graph_metadata 内新增 is_draft_extend_v2() 分支:V2 使用 spec_info.num_tokens_per_req 作为统一步长填充 cu_seqlens_qmax_seq_len_q,不再沿用 V1 的 num_accept_tokens 变长语义。
  3. Graph 输出精简:在 eagle_draft_extend_cuda_graph_runner.pyreplay 方法中,针对 DRAFT_EXTEND_V2 模式,跳过 topk_ptopk_index 的切片赋值,只保留 next_token_logitshidden_states 的输出(graph 内仍通过 torch.amax 锚定全 logits 以满足 CUDA graph 生命周期要求,但不产生 top-k 输出)。
  4. 测试适配:在 speculative_draft_extend_runner.py 中新增 _assert_draft_extend_v2_outputs_close 函数,仅比较 logits 和 hidden_states,不再断言 topk 字段;并在 dense/MLA 的两条 V2 测试用例中将其挂载为 assert_outputs_close
文件 模块 状态 重要度
python/sglang/srt/layers/attention/trtllm_mha_backend.py 注意力后端 modified 6.65
python/sglang/srt/speculative/eagle_worker_v2.py 推测解码 modified 6.04
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py CUDA Graph 运行器 modified 5.7
python/sglang/test/kits/attention_unittest/runner_modes/speculative_draft_extend_runner.py 测试 modified 5.35

关键符号

_assert_draft_extend_v2_outputs_close

关键源码片段

python/sglang/srt/layers/attention/trtllm_mha_backend.py core-logic

核心:在 `_build_cuda_graph_metadata` 和 `_apply_cuda_graph_metadata` 中扩展 `draft_extend` 条件以包含 V2,并新增 V2 专属的 metadata 填充逻辑。

# python/sglang/srt/layers/attention/trtllm_mha_backend.py(片段)# 在 _build_cuda_graph_metadata 中,将 draft_extend 分支条件扩展为 include_v2
elif forward_mode.is_draft_extend(include_v2=True): # ← 同时覆盖 V1 和 V2
    num_tokens_per_bs = num_tokens // bs
    metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][:bs]
    metadata.cu_seqlens_q = self.draft_extend_metadata["cu_seqlens_q"][:bs + 1]
    metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][:bs + 1]
    metadata.max_seq_len_q = num_tokens_per_bs
    metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
    # ... bind swa page table
    self.draft_extend_metadata[bs] = metadata# 在 _apply_cuda_graph_metadata 中,之前的条件也改为 include_v2
elif forward_mode.is_draft_extend(include_v2=True):
    metadata = self.draft_extend_metadata[bs]
    metadata.cache_seqlens_int32.copy_(seq_lens)
    metadata.max_seq_len_k = seq_lens_cpu.max().item()
    # ... cu_seqlens_k cumsum
​
    # V2 与 V1 分流:V2 使用 num_tokens_per_req 作为一致步长
    if forward_mode.is_draft_extend_v2():
        num_tokens_per_bs = spec_info.num_tokens_per_req
        if num_tokens_per_bs <= 0:
            # 捕获阶段使用合成输入,fallback 推断步长
            num_tokens_per_bs = int(
                spec_info.num_accept_tokens[:bs].max().item()
            )
        metadata.max_seq_len_q = num_tokens_per_bs
        # cu_seqlens_q 填充为等差数列:0, step, 2*step, ...
        metadata.cu_seqlens_q[1:].copy_(
            torch.arange(
                num_tokens_per_bs,
                bs * num_tokens_per_bs + 1,
                num_tokens_per_bs,
                dtype=torch.int32,
                device=metadata.cu_seqlens_q.device,
            )
        )
    else:
        # V1 分支不变:使用 num_accept_tokens 变长填充
        extend_lens = spec_info.num_accept_tokens[:bs]
        if spec_info.num_accept_tokens_cpu:
            metadata.max_seq_len_q = max(spec_info.num_accept_tokens_cpu)
        else:
            metadata.max_seq_len_q = 1
        metadata.cu_seqlens_q[1:].copy_(
            torch.cumsum(extend_lens, dim=0, dtype=torch.int32)
        )
python/sglang/srt/speculative/eagle_worker_v2.py dependency-wiring

入口:在此文件中导入 TRTLLMHAAttnBackend 并将其加入白名单,同时内部 `_draft_extend_for_decode` 添加注释说明 graph 输出语义。

# python/sglang/srt/speculative/eagle_worker_v2.py(片段)from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend # 新增导入# ... 在 init_cuda_graphs 中
supports_cuda_draft_extend_graph = (_is_cuda or _is_musa) and isinstance(
    self.draft_extend_attn_backend,
    (
        TritonAttnBackend,
        TRTLLMMLABackend,
        TRTLLMHAAttnBackend, # 新增
        TokenspeedMLABackend,
    ),
)# 在 _draft_extend_for_decode 中,graph 输出只锚定 logits,top-k 由 worker 后算
# The draft-extend graph only anchors full logits; selected-row topk is
# owned by the worker for both graph and eager paths.
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py core-logic

核心逻辑:在 replay 中对 V2 模式跳过 topk 复制,避免冗余计算。

# python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py(片段)# 在 replay 方法末尾,unpadding 部分
if unpadding_bs is not None:
    out_copy = out
    # 构造只含 logits 和 hidden_states 的输出,不拷贝 topk
    out = LogitsProcessorOutput(
        next_token_logits=out.next_token_logits[:unpadding_bs],
        hidden_states=out.hidden_states[:unpadding_bs],
    )
    # 对于 V2 模式,graph 内部已通过 torch.amax 锚定 logits 但未输出 topk
    if self.forward_mode != ForwardMode.DRAFT_EXTEND_V2:
        out.topk_p = out_copy.topk_p[:raw_bs]
        out.topk_index = out_copy.topk_index[:raw_bs]
return out

评论区精华

isinstance 调用风格 style

merrymercy 建议将新增的 `or isinstance(...)` 与已有用法统一为元组形式 `isinstance(self.draft_extend_attn_backend, (..., TRTLLMHAAttnBackend))`。

结论:作者接受并修改,简化代码。 · 已解决

风险与影响

  1. 兼容性:is_draft_extend(include_v2=True) 会同时匹配 V1 和 V2,需确认原 V1 的 draft_extend 路径在 metadata 构建和 replay 中行为不变(patch 中在 _apply_cuda_graph_metadata 内部用 is_draft_extend_v2() 分流,不影响 V1 逻辑)。
  2. 边界条件:V2 的 num_tokens_per_req 可能在 capture 阶段为 0(合成输入),代码中 fallback 到 num_accept_tokens[:bs].max().item(),若所有 accept_tokens 也为 0 可能导致异常,但实际 capture 输入设计保证了至少有一个 token。
  3. 测试覆盖:新增的 V2 断言不再校验 topk,如果未来修改了 worker 与 graph 的 top-k 交付契约,测试可能无法捕获回归。但当前设计明确将 top-k 计算后置到 worker,因此测试只验证 graph 实际锚定的输出是合理的。

对用户:当使用 trtllm_mha 作为 draft-extend attention backend 且开启 spec_v2 时,draft-extend 步骤将自动享受 CUDA graph 加速,同时消除之前 graph 中无用的 top-k 计算,提升推理吞吐。不影响已有 V1 行为或其它 backend 路径。
对系统:无新增配置项,启动时自动生效。需要 CUDA 环境且 trtllm_mha 可用。

路径兼容性 边界条件依赖捕获输入 测试契约变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论