Prhub

#26134 [refactor] unify cuda-graph capture/replay across attention backends

原始 PR 作者 ch-wan 合并时间 2026-05-23 09:51 文件变更 5 提交数 3 评论 7 代码增减 +455 / -609

执行摘要

统一 5 个 attention 后端的 CUDA 图捕获 / 重放逻辑,消除重复代码并修复潜在 bug

PR描述指出init_forward_metadata_capture_cuda_graph和init_forward_metadata_replay_cuda_graph在每个attention后端中共享大量逻辑但各自维护独立副本,导致实际出现了分歧(如WaveAttnBackend跳过get_num_kv_splits,CutlassMLABackend含无用断言,FlashInferAttnBackend存在约240行重复的模式分派代码)。统一逻辑可消除这些隐患,降低维护成本。

值得精读。此PR展示了大型重构中消除重复的经典手法:提取公共方法、capture委托replay、工厂方法封装。对从事推理加速、CUDA图优化、attention模块开发的工程师有重要参考价值。建议在下次设计新的attention后端时,直接参考此PR总结的公共接口。

讨论亮点

本PR在合并前无 reviewer 技术讨论。作者在描述中详细说明了每个后端的差异化修改,并通过在评论中执行 /rerun-test 手动触发多项 CI 测试(test_basic_sanity、test_mla_flashinfer、test_eagle_infer 等),所有测试结果均通过。此外,作者在 GB300 上对 CutlassMLABackend 和 WaveAttnBackend 进行了完整的精度与功能验证,确保重构后无回归。

实现拆解

  1. 提取公共辅助函数:在 TritonBackend 中新增 _fill_kv_indptr_and_indices_update_decode_kv_buffers_update_target_verify_buffers_update_draft_extend_buffers_build_cuda_graph_forward_metadata 等方法,将分散在 capturereplay 中的重复缓冲填充逻辑集中到带文档字符串的独立方法中,便于后续后端复用。

  2. 采用 capture 委托 replay 模式(Pattern A):对 WaveAttnBackendCutlassMLABackendcapture 在非多步推测路径下直接调用 replay 完成缓冲写入,再通过 _build_cuda_graph_forward_metadata 冻结 ForwardMetadata。此举确保 capturereplay 使用完全一致的缓冲设置,修复了 WaveAttnBackendget_num_kv_splits 在 capture 阶段缺失的问题。

  3. 重构 FlashInferAttnBackend 为 prepare+replay 模式(Pattern B):提取 _create_decode_wrappers_create_prefill_wrappers 工厂方法负责包装器的构建(参数集中管理),新增 _prepare_cuda_graph_metadata 统一处理模式分派。capture 简化为调用 _prepare_cuda_graph_metadata 后调用 replay,不再包含分发分支。同时将 replay 方法中 is_target_verifyis_draft_extend 两条完全相同逻辑的分支合并。

  4. 合并 FlashInferMLAAttnBackend 的重复分支:将该后端 capturereplay 方法中 is_target_verifyis_draft_extend 两个完全相同的分支分别合并为一个条件,减少代码量约38行。

  5. 适配 CutlassMLABackend 并移除无用断言:采用 Pattern A,将 capture 委托给 replay,同时移除了 replay 中从未读取的 assert seq_lens_cpu is not None,使代码更简洁。

文件 模块 状态 重要度
python/sglang/srt/layers/attention/triton_backend.py 注意力后端 modified 8.69
python/sglang/srt/layers/attention/flashinfer_backend.py 注意力后端 modified 8.68
python/sglang/srt/layers/attention/wave_backend.py 注意力后端 modified 7.17
python/sglang/srt/layers/attention/flashinfer_mla_backend.py 注意力后端 modified 6.0
python/sglang/srt/layers/attention/cutlass_mla_backend.py 注意力后端 modified 5.97

关键符号

_fill_kv_indptr_and_indices _update_decode_kv_buffers _update_target_verify_buffers _update_draft_extend_buffers _build_cuda_graph_forward_metadata update_sliding_window_buffer_cuda_graph init_forward_metadata_capture_cuda_graph _create_decode_wrappers _create_prefill_wrappers _prepare_cuda_graph_metadata

关键源码片段

python/sglang/srt/layers/attention/triton_backend.py core-logic

最核心的重构目标,新增 5 个辅助方法将 capture/replay 中的重复缓冲逻辑提取为独立函数,为其他后端提供复用模板。

# triton_backend.py - 新提取的缓冲填充辅助方法def _fill_kv_indptr_and_indices(
    self,
    bs: int,
    seq_lens: torch.Tensor,
    req_pool_indices: torch.Tensor,
    kv_indices: torch.Tensor,
) -> torch.Tensor:
    """填充 kv_indptr 和 kv_indices 的通用方法,被 capture/replay 共用"""
    kv_indptr = self.kv_indptr[: bs + 1]
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    # 使用 Triton 内核从 req_pool_indices 构建 flat kv_indices
    create_flashinfer_kv_indices_triton[(bs,)](
        self.req_to_token,
        req_pool_indices,
        seq_lens,
        kv_indptr,
        None,
        kv_indices,
        self.req_to_token.stride(0),
    )
    return kv_indptrdef _update_decode_kv_buffers(
    self,
    bs: int,
    seq_lens: torch.Tensor,
    req_pool_indices: torch.Tensor,
):
    """填充 decode/idle 模式下 CUDA 图所需的 KV 缓冲(含滑动窗口)"""
    seq_lens = seq_lens[:bs]
    req_pool_indices = req_pool_indices[:bs]
    kv_indptr = self._fill_kv_indptr_and_indices(
        bs, seq_lens, req_pool_indices, self.cuda_graph_kv_indices
    )
    window_kv_indptr = self.window_kv_indptr
    window_kv_lens = None
    if self.sliding_window_size is not None and self.sliding_window_size > 0:
        window_kv_indptr, _, window_kv_lens, _ = update_sliding_window_buffer(
            self.window_kv_indptr,
            self.req_to_token,
            self.sliding_window_size,
            seq_lens,
            req_pool_indices,
            bs,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            window_kv_indices=self.cuda_graph_window_kv_indices,
        )
    return kv_indptr, window_kv_indptr, window_kv_lens
python/sglang/srt/layers/attention/wave_backend.py core-logic

采用 Pattern A,新增 _build_cuda_graph_forward_metadata 方法,capture 委托 replay 后调用此方法冻结元数据,修复 get_num_kv_splits 缺失 bug。

# wave_backend.py - 从预分配缓冲构造 ForwardMetadata,capture 委托 replaydef _build_cuda_graph_forward_metadata(
    self,
    bs: int,
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInput],
) -> ForwardMetadata:
    """从 CUDA 图的预分配缓冲中组装 ForwardMetadata(不执行任何核心计算)"""
    if forward_mode.is_decode_or_idle():
        return ForwardMetadata(
            attn_logits=self.cuda_graph_attn_logits,
            attn_lse=self.cuda_graph_attn_lse,
            max_extend_len=None,
            num_kv_splits=self.cuda_graph_num_kv_splits,
            kv_indptr=self.kv_indptr[: bs + 1],
            kv_indices=self.cuda_graph_kv_indices,
            qo_indptr=None,
            custom_mask=None,
            mask_indptr=None,
        )
    elif forward_mode.is_target_verify():
        return ForwardMetadata(
            attn_logits=None,
            attn_lse=None,
            max_extend_len=self.num_draft_tokens,
            num_kv_splits=None,
            kv_indptr=self.kv_indptr[: bs + 1],
            kv_indices=self.cuda_graph_kv_indices,
            qo_indptr=self.qo_indptr[: bs + 1],
            custom_mask=self.cuda_graph_custom_mask,
            mask_indptr=self.mask_indptr[: bs + 1],
        )
    else:
        raise ValueError(f"Invalid forward mode: {forward_mode=} for CUDA Graph.")def init_forward_metadata_capture_cuda_graph(
    self,
    bs: int,
    num_tokens: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInput],
):
    assert encoder_lens is None, "Not supported"
    # 多步推测路径:kv 缓冲来自 spec_info,不经过 replay 路径
    if forward_mode.is_decode_or_idle() and spec_info is not None:
        self.forward_metadata = ForwardMetadata(
            attn_logits=self.cuda_graph_attn_logits,
            attn_lse=self.cuda_graph_attn_lse,
            max_extend_len=None,
            num_kv_splits=self.cuda_graph_num_kv_splits,
            kv_indptr=spec_info.kv_indptr,
            kv_indices=spec_info.kv_indices,
            qo_indptr=None,
            custom_mask=None,
            mask_indptr=None,
        )
        return
    # 常规路径:委托给 replay 完成缓冲更新,然后用 _build_cuda_graph_forward_metadata 冻结元数据
    self.init_forward_metadata_replay_cuda_graph(
        bs=bs,
        req_pool_indices=req_pool_indices,
        seq_lens=seq_lens,
        seq_lens_sum=None,
        encoder_lens=encoder_lens,
        forward_mode=forward_mode,
        spec_info=spec_info,
        seq_lens_cpu=None,
    )
    self.forward_metadata = self._build_cuda_graph_forward_metadata(
        bs, forward_mode, spec_info
    )

评论区精华

WaveAttnBackend 修复 get_num_kv_splits 在 capture 阶段缺失 正确性

PR 作者在 body 中指出 WaveAttnBackend 在 CUDA graph capture 时未调用 get_num_kv_splits,而 replay 时调用,导致 decode 状态不一致。

结论:通过将 capture 委托给 replay 并在之后调用 _build_cuda_graph_forward_metadata 来修复,GB300 验证通过。 · 已解决

多项 CI 测试确认正确性 测试

作者在评论中使用 /rerun-test 触发 test_basic_sanity、test_mla_flashinfer、test_eagle_infer 等多项测试,结果全部通过。

结论:CI 测试结果均为绿色,无失败用例。 · 已解决

风险与影响

主要风险来自对5个attention后端核心CUDA图路径的大幅改动,尽管作者在GB300上完成了全面验证,但其他硬件平台(如H100、AMD MI300等)未经相同范围的覆盖。此外,还有多个后端(flashattention、trtllm等)沿用旧的Pattern C,未来统一时需注意接口和语义的一致性。整体风险可控,但建议在更多CI runner上补充回归测试。

对用户无功能影响,推理结果完全一致。对开发团队,显著降低了CUDA图相关代码的维护成本,新增attention后端可以直接沿用已提取的公共模式。架构上为后续统一所有后端奠定了基础。团队需要继续完成Pattern C后端的迁移,并更新相关内部文档。

核心路径变更 多个后端同步修改 缺少测试配套改动

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论