Prhub

#26742 [refactor] Unify CUDA graph runner input buffers behind CudaGraphBufferRegistry

原始 PR 作者 ch-wan 合并时间 2026-06-04 01:54 文件变更 6 提交数 30 评论 3 代码增减 +2242 / -249

执行摘要

统一 CUDA Graph 输入缓冲区管理

三个 CUDA graph runner(CudaGraphRunner, PiecewiseCudaGraphRunner, BreakableCudaGraphRunner)各自维护独立的手写输入缓冲区填充与复制逻辑,逐字段地操作 ForwardBatch 属性,代码重复且易错。PR 通过引入声明式注册表统一管理这些缓冲区,降低维护成本和新增功能时的出错概率。

推荐所有涉及 CUDA graph 性能优化或模型推理的开发者阅读。该 PR 展示了如何通过声明式注册表简化复杂的手写数据搬运逻辑,设计模式值得借鉴。特别关注未解决的 DSV4 replay 正确性问题,建议团队尽快安排后续修复或至少更新文档说明限制。对于不涉及 DSV4 的模型,重构风险较低可正常合入。

讨论亮点

DSV4 重放路径正确性争议
Reviewer chatgpt-codex-connector[bot]piecewise_cuda_graph_runner.py 的 replay 方法处提出 P1 级别问题:重构后 init_forward_metadata 仅在捕获时执行,而原代码在 layer 循环前有 _maybe_upgrade_forward_metadata() 调用,现在缺失,导致 DeepSeek V4 的 forward_metadata 类型可能未被正确升级,在模型前向推导时解引用 core_metadata 等字段会崩溃。该评论至今未获作者回应,PR 已合并,风险未被修复。

实现拆解

步骤1:定义缓冲区注册表核心cuda_graph_buffer_registry.py
引入 PaddingPolicy 枚举(5 种策略:KEEP_PAD、FILL_SENTINEL、ZERO、FOREACH_COPY、FILL_ONCE)、FillContext 数据类、GraphSlot 数据类(描述一个 ForwardBatch 字段对应的缓冲区,包含名称、形状函数、dtype、轴、填充策略、可选的 post_fill 钩子、slice_fn、source_fn)以及 CudaGraphBufferRegistry 类。注册表通过 register_slot 注册插槽(自动分配或通过 bind 采用现有张量),fill_from 方法按 (raw < padded) 情况执行填充重置、按 dtype 分组的 D2D 复制、post_fill 钩子。同时提供 build_decode_registry(bs 轴)和 build_prefill_registry(token 轴)工厂,封装了常见字段的插槽声明。

步骤2:重构输入缓冲区池input_buffers.py
ForwardInputBuffers._share_one_buffer 方法提升为模块级别函数 share_input_buffer,其键由纯名称改为 (name, numel, dtype, device) 四元组,确保相同规格的缓冲区始终共享同一物理内存,消除注册顺序对共享结果的影响。同时保留 ForwardInputBuffers 作为数据类包装。

步骤3:迁移全图解码 Runnercuda_graph_runner.py
CudaGraphRunner__init__ 中调用 build_decode_registry(source=self.buffers) 构建注册表,采用现有 DecodeInputBuffers 的存储。populate_from_forward_batch 方法逐步移除手写复制/填充逻辑,最终仅委托给 registry.fill_from。捕获阶段通过 registry.get_slot(name).slice_for(...) 获取输入视图,替代原有的缓冲区直接切片。

步骤4:迁移分片/可中断 Runnerpiecewise_cuda_graph_runner.pybreakable_cuda_graph_runner.py
类似地,这两个 runner 在初始化时调用 build_prefill_registry(source=self.buffers) 构建 token 轴注册表。replay_prepare 由内联填充/复制替换为 registry.fill_from,捕获阶段从注册表获取输入切片。可中断 runner 复用分片 runner 的注册表构建。所有手动维护的 PrefillInputBuffers 相关逻辑被移除。

步骤5:单元测试覆盖test_cuda_graph_buffer_registry.py
新增 33 个 CPU 单元测试,覆盖插槽注册、每种填充策略的初始化/重置/复制、fill_fromextract_buffer、post_fill 钩子、source_fn、两个工厂函数以及共享池语义。测试使用 cpu 设备,全面验证注册表逻辑而不依赖 GPU 环境。

文件 模块 状态 重要度
python/sglang/srt/model_executor/cuda_graph_buffer_registry.py 缓冲区注册表 added 9.36
python/sglang/srt/model_executor/cuda_graph_runner.py 图形执行器 modified 8.75
test/registered/unit/model_executor/test_cuda_graph_buffer_registry.py 单元测试 added 8.26
python/sglang/srt/model_executor/input_buffers.py 输入缓冲区 modified 8.04
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py 分片执行器 modified 7.73
python/sglang/srt/model_executor/breakable_cuda_graph_runner.py 可中断执行器 modified 7.28

关键符号

_grouped_foreach_copy_ CudaGraphBufferRegistry.fill_from CudaGraphBufferRegistry.register_slot build_decode_registry build_prefill_registry share_input_buffer GraphSlot.slice_for

关键源码片段

python/sglang/srt/model_executor/cuda_graph_buffer_registry.py data-contract

核心新文件,定义 CudaGraphBufferRegistry、GraphSlot、PaddingPolicy 等,是重构的基石。

from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.model_executor.input_buffers import share_input_buffer
if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch_has_foreach_copy = hasattr(torch, '_foreach_copy_')def _grouped_foreach_copy_(dsts: List[torch.Tensor], srcs: List[torch.Tensor]) -> None:
    # 将多个 D2D 复制按 (dst.dtype, src.dtype) 分组,因为 torch._foreach_copy_ 要求同组 dtype 一致
    def _foreach_copy(group_dsts, group_srcs):
        if _has_foreach_copy:
            torch._foreach_copy_(group_dsts, group_srcs)
        else:
            for dst, src in zip(group_dsts, group_srcs):
                dst.copy_(src)
    groups: Dict[Tuple[torch.dtype, torch.dtype], Tuple[List, List]] = {}
    for dst, src in zip(dsts, srcs):
        key = (dst.dtype, src.dtype)
        if key not in groups:
            groups[key] = ([], [])
        groups[key][0].append(dst)
        groups[key][1].append(src)
    for group_dsts, group_srcs in groups.values():
        _foreach_copy(group_dsts, group_srcs)class PaddingPolicy(Enum):
    # raw < padded 时如何处理尾部填充区域
    KEEP_PAD = 'keep_pad'
    FILL_SENTINEL = 'fill_sentinel'
    ZERO = 'zero'
    FOREACH_COPY = 'foreach_copy'
    FILL_ONCE = 'fill_once'@dataclass
class FillContext:
    raw_bs: int
    padded_bs: int
    raw_num_tokens: int
    padded_num_tokens: int
    pp_proxy_tensors: Optional[Any] = None@dataclass
class GraphSlot:
    name: str
    shape_fn: Callable[[int, int], Tuple]
    dtype: torch.dtype
    axis: str # 'bs'/'tokens'/'none'
    device: Optional[torch.device] = None
    padding_policy: PaddingPolicy = PaddingPolicy.KEEP_PAD
    pad_value: int = 0
    post_fill: Optional[Callable] = None
    slice_fn: Optional[Callable] = None
    source_fn: Optional[Callable] = None
    copy_from_fb: bool = True
    buffer: Optional[torch.Tensor] = None
​
    def __post_init__(self):
        if self.axis not in ('bs', 'tokens', 'none'):
            raise ValueError(f'axis must be bs/tokens/none, got {self.axis}')
        if self.axis == 'none' and self.padding_policy != PaddingPolicy.KEEP_PAD:
            raise ValueError('axis=none only supports KEEP_PAD')
​
    def slice_for(self, padded_bs: int, padded_num_tokens: int) -> torch.Tensor:
        if self.buffer is None:
            raise RuntimeError('buffer not yet allocated')
        if self.axis == 'bs':
            return self.buffer[:padded_bs]
        elif self.axis == 'tokens':
            return self.buffer[:padded_num_tokens]
        else:
            return self.buffer
python/sglang/srt/model_executor/cuda_graph_runner.py core-logic

主修改文件,全图解码 runner 通过注册表重写 populate 和 capture 路径,代码量减少 144 行。

# cuda_graph_runner.py ( 关键改动摘录 )from sglang.srt.model_executor.cuda_graph_buffer_registry import build_decode_registryclass CudaGraphRunner:
    def __init__(self, model_runner):
        # ... 其他初始化 ...
        self.buffers.share_buffers()
        # 构建 bs 轴注册表,复用 DecodeInputBuffers 的存储
        self.buffer_registry = build_decode_registry(
            device=self.device,
            max_bs=self.max_bs,
            max_num_token=self.max_num_tokens,
            cache_loc_dtype=self._cache_loc_dtype(),
            seq_len_fill_value=self.model_runner.seq_len_fill_value,
            is_encoder_decoder=self.model_runner.is_encoder_decoder,
            encoder_len_fill_value=self.model_runner.encoder_len_fill_value,
            enable_mamba_track=self.mamba_track_enabled,
            source=self.buffers,
        )
​
    def populate_from_forward_batch(self, forward_batch, ...):
        self.buffer_registry.fill_from(
            forward_batch,
            bs=bs,
            raw_bs=raw_bs,
            num_tokens=num_tokens,
            seq_len_fill_value=seq_len_fill_value,
            capture_forward_mode=capture_forward_mode,
            pp_proxy_tensors=pp_proxy_tensors,
        )

评论区精华

DSV4 piecewise replay 中 metadata 升级缺失 正确性

reviewer chatgpt-codex-connector[bot] 在 piecewise_cuda_graph_runner.py 的 replay 方法评论指出,重构后 init_forward_metadata 仅在捕获时执行,重放时移除了 _maybe_upgrade_forward_metadata 调用,可能导致 DeepseekV4AttnBackend 的 forward_metadata 类型错误,在模型前向推导时崩溃。

结论:PR 作者未回复该评论,PR 已合并,正确性问题未在本次重构中解决。 · unresolved

风险与影响

  • DSV4 replay 正确性风险:如上所述,分片 Runner 的 metadata 初始化路径被抽离,用户若使用 DSV4 并启用 piecewise CUDA graph(SGLANG_PREP_IN_CUDA_GRAPH=1)可能出现崩溃。影响面为 DSV4 用户且开启该功能的情况。
  • 共享池语义改变风险:输入缓冲区键从名称改为四元组,若存在相同名称但不同大小的缓冲区,它们将不再共享同一张量,可能略微增加内存占用,但不会引起功能错误。
  • 回归风险:重构涉及三个 runner 的核心路径,尽管设计为行为保持,但任何复制逻辑中的边界条件、dtype 分组顺序细微偏差都可能导致数据不一致。新测试仅覆盖注册表逻辑单元,未进行 GPU 集成验证,已有的集成测试可能不足。
  • 依赖 torch._foreach_copy_ 的风险:在 CUDA 上依赖该函数实现高效复制,若不可用则 fallback 到逐张量 copy,回退路径仅通过 CPU 测试,CUDA 上的性能可能劣化。
  • 用户影响:无 API 变化,E2E 行为不变,但 DSV4 用户存在潜在的崩溃风险。
  • 系统影响:新增注册表模块和函数式共享池,扩展(如自定义 runner)需要适配新接口;未来新增 ForwardBatch 字段只需在工厂函数中注册 GraphSlot,无需修改所有 runner 的复制逻辑,维护成本降低。
  • 团队影响:代码集中化、可预测性提高,但团队成员需要学习注册表概念;commit 历史(30 步递增)便于审阅和回退。
DSV4 replay 正确性未验证 共享池键变更可能影响兼容性 重构覆盖核心图执行路径 缺少 GPU 集成测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论