Prhub

#26665 [refactor] unify cuda-graph capture/replay across attention backends

原始 PR 作者 ch-wan 合并时间 2026-05-30 03:46 文件变更 19 提交数 20 评论 4 代码增减 +1066 / -1575

执行摘要

统一 Attention 后端 CUDA Graph capture/replay

重新落地被回滚的 PR #26134,并扩展覆盖额外 4 个后端。单 PR 取代之前堆叠的分支系列(#26144、#26159、#26160、#26162),避免链式依赖。PR 声明为纯重构,不改变计算路径与性能。

值得深入阅读,尤其是提取的 Pattern A/B 设计,可作为未来添加新注意力后端的模板。PR 提交颗粒度清晰,每条 commit 对应一个后端,易于 review。建议阅读 commits 中的详细消息(如 FlashMLABackend 的 q_head_mult 偏移技巧)。对于维护者,建议运行完整的注意力单元测试套件以确保无回归。

讨论亮点

review 中 chatgpt-codex-connector[bot] 指出两个潜在问题:

  • Ascend GDN 遗留:capture 调用 replay 时传入 seq_lens_cpu=None,但 _replay_metadata 无条件比较 seq_lens_cpu == self.get_cuda_graph_seq_len_fill_value(),可能引起异常。建议要么保持原有 capture 路径,要么传递 seq_lens.cpu()
  • NPU DLLM 遗漏:capture 路径不再初始化 seq_lens_cpu_list / seq_lens_list_cumsum,导致 forward_dllm 使用 None/ stale 长度。
    两个问题在 PR 合并前未见明确修复,建议关注后续补丁。

实现拆解

  1. 提取统一模式 - Pattern A:capture 先创建 metadata 对象并绑定预分配 buffer 切片,然后委托给 replay 填充运行时数据(如 seq_lens、page_table)。典型后端:FlashAttention(通过 _bind_metadata_buffers)、TRTLLM-MHA(通过 _build_cuda_graph_metadata)。
    - Pattern B:capture 与 replay 共享 buffer 创建逻辑,capture 额外调用一次 replay 以完成初始化。典型后端:FlashInfer(通过 _prepare_cuda_graph_metadata)。
  2. 逐后端应用重构 - commit 顺序:FlashMLABackend → TRTLLMMHABackend → FlashAttentionBackend → TRTLLMMLABackend → DualChunk → Mamba → Aiter → Lightning → AscendGDN → Ascend → DeepSeekSparse → DeepSeekV4 → DeepSeekV4HIPRadix → FlashInferMLA。每个后端按对应模式改造 init_forward_metadata_capture_cuda_graphinit_forward_metadata_replay_cuda_graph。关键文件示例:triton_backend.py 新增 _fill_kv_indptr_and_indices_update_decode_kv_buffers 等辅助方法;flashinfer_backend.py 提取 _create_decode_wrappers_create_prefill_wrappers
  3. 处理边界与冲突 - FlashAttention topk>1 target_verify 不能委托 replay,因 capture 时 dummy spec_info 缺少 positions/custom_mask,保留原 capture 路径。
    - DraftExtend 模式 replay 后需恢复 max_seq_len_q(bake 为常量)。
    - 与 SWA fix PR #26152 冲突,通过合并方案解决(在 triton_backend.py 中保留 invalidate_loc_cache 调用)。
  4. 测试与验证 - 新增 test/registered/attention/unittests/dense/test_tbo.py,构造 TboAttnBackend(primary=fa3, children=[fa3, fa3]) 链,直接调用 init_forward_metadata_capture_cuda_graph,验证无 KeyError: bs 异常。
    - 现有 accuracy 与 speed 测试通过(CI 绿色)。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/triton_backend.py 注意力层 modified 8.93
python/sglang/srt/layers/attention/flashinfer_backend.py 注意力层 modified 8.92
python/sglang/srt/layers/attention/flashattention_backend.py 注意力层 modified 8.22
python/sglang/srt/layers/attention/trtllm_mha_backend.py 注意力层 modified 8.02
test/registered/attention/unittests/dense/test_tbo.py 回归测试 added 6.0

关键符号

init_forward_metadata_capture_cuda_graph init_forward_metadata_replay_cuda_graph _build_cuda_graph_forward_metadata _bind_metadata_buffers _prepare_cuda_graph_metadata _create_decode_wrappers _create_prefill_wrappers _fill_kv_indptr_and_indices _update_decode_kv_buffers _update_target_verify_buffers _update_draft_extend_buffers update_sliding_window_buffer_cuda_graph _build_cuda_graph_metadata _init_cuda_graph_metadata

关键源码片段

python/sglang/srt/layers/attention/flashattention_backend.py core-logic

FlashAttention 后端通过 _bind_metadata_buffers 将原 250 行的 capture 函数缩减为约 20 行,是 Pattern A 的典型代表。

def _bind_metadata_buffers(
    self,
    bs: int,
    num_tokens: int,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInput],
    device: torch.device,
) -> tuple:
    """Create FlashAttentionMetadata with pre-allocated buffer slice refs.    Assigns all buffer slice references but does NOT fill data values.
    Stores the new metadata object(s) in the appropriate lookup dicts.
    Returns (metadata, metadata_expand).
    """
    metadata = FlashAttentionMetadata()
    metadata_expand = FlashAttentionMetadata()
​
    if forward_mode.is_decode_or_idle():
        if spec_info is not None:
            if self.topk <= 1:
                # Draft Decode topk=1: 绑定预分配 buffer 的切片引用
                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                    "cache_seqlens"][:bs]
                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
                    "cu_seqlens_q"][:bs + 1]
                metadata.cu_seqlens_k = self.decode_cuda_graph_metadata[
                    "cu_seqlens_k"][:bs + 1]
                metadata.page_table = self.decode_cuda_graph_metadata[
                    "page_table_draft_decode"][:bs, :]
                if self.use_sliding_window_kv_pool:
                    metadata.swa_page_table = self.decode_cuda_graph_metadata[
                        "swa_page_table"][:bs, :]
                self.decode_cuda_graph_metadata[bs] = metadata
            else:
                # topk>1 需要两个 metadata 对象
                # ...

评论区精华

Ascend GDN capture 传递 seq_lens_cpu=None 可能导致失败 正确性

reviewer 指出 Ascend GDN capture 调用 replay 时传入 seq_lens_cpu=None,但 _replay_metadata 无条件比较 seq_lens_cpu == self.get_cuda_graph_seq_len_fill_value(),可能引起比较异常。建议要么保持原有 capture 路径,要么传递 seq_lens.cpu()。

结论:未见到作者直接回复,PR 已合并,可能已在其他提交或后续修复中覆盖。 · 待处理

NPU DLLM capture 未初始化 seq_lens_cpu_list 等字段 正确性

reviewer 指出 NPU DLLM 模型在 capture 路径不再初始化 seq_lens_cpu_list/seq_lens_list_cumsum,导致 forward_dllm 使用 None/stale 长度,可能导致 npu_fused_xxx 传入错误数据。

结论:同上,未明确修复。 · 待处理

风险与影响

  1. 大量后端的统一重构可能导致某些特殊模式(如 FlashAttention topk>1 target_verify)被错误地委托给 replay,已在代码中显式跳过,但仍有遗漏风险。
  2. Ascend/NPU 后端的 capture 路径简化可能遗漏初始化字段(review 指出的两个问题),可能导致运行时错误。
  3. 统一模式依赖 seq_lens_cpu 参数传递,部分后端可能在 capture 时传递 None 引发比较异常。
  4. 由于是重新落地被回滚的 PR,需确保前次回滚的所有问题都已修复。

影响范围:使用 CUDA Graph 的所有注意力后端(约 16 个),删除重复代码约 1500 行,统一维护逻辑。用户无感知(纯重构),新后端开发可复用统一模式。系统稳定性依赖后续持续验证。

NPU / Ascend 边界初始化问题 TBO capture 修复依赖测试 统一模式可能遗漏特殊 forward mode

关联 Issue

#26152 fix(swa): eliminate spurious translate_loc_from_full_to_swa warning in BCG and CG paths

完整报告

参与讨论