Prhub

#22218 [Experimental] Breakable Piecewise Cuda Graph

原始 PR 作者 Oasis-Git 合并时间 2026-04-24 19:33 文件变更 8 提交数 52 评论 44 代码增减 +664 / -147

执行摘要

实现不依赖 torch.compile 的可打断 CUDA 图(BCG)

受 #19102 和 @cctry 工作的启发,旨在提供一种更简单的 piecewise CUDA graph 方案,避免对 torch.compile 后端和 FX 图追踪的依赖,降低维护复杂度并提高可调试性。

该 PR 值得所有关注调度优化的开发者精读,尤其是 breakable_cuda_graph_runner.pybreakable_cuda_graph.py 中的图捕获与回放机制。设计上使用 contextvar 和函数装饰器实现图打断,比基于 FX 的路径更直观且易调试。但在生产环境中启用前,应在目标模型架构上进行充分的性能验证,并留意弱引用张量的边界问题。

讨论亮点

Review 中最核心的讨论集中在以下几方面:

  • 参数设计:merrymercy 指出 enable_breakable_cuda_graph 单一参数无法灵活组合解环(prefill)和解码(decode)的模式,建议引入 --cuda-graph-mode 字典或分离参数。
  • 继承关系:merrymercy 和 cctry 均建议不要再从 PiecewiseCudaGraphRunner 继承,最终实现完全独立。
  • radix_attention 复杂度:ispobock 和 cctry 认为 forward 中 BCG/PCG 的分支逻辑使代码复杂,需设计更通用的函数处理 MLA/MHA 差异;merrymercy 建议将处理函数抽离为 processing_mla
  • 弱引用张量安全性:frgossen 提出弱引用张量在共享内存池地址复用可能导致数据损坏,cctry 承认风险存在,表示实际使用中未传递弱引用跨段。
  • 本地脚本:ispobock 要求删除本地测试脚本 run_bcg_comparison.sh,Oasis-Git 已移除。

实现拆解

  1. 新增 Breakable CUDAGraph 核心库 (breakable_cuda_graph/breakable_cuda_graph.py):定义 BreakableCUDAGraph 类,用 _current_capture_var ContextVar 追踪当前捕获上下文;eager_on_graph 装饰器在函数调用处打断当前 CUDA 图并结束该段,开始新段。每个段都是真实的 torch.cuda.CUDAGraph,通过共享内存池的 use_count 管理生命周期,使得弱引用张量在回放时有效。
  2. 新增独立上下文管理模块 (breakable_cuda_graph/context.py):提供 enable_breakable_cuda_graph 上下文管理器和 is_in_breakable_cuda_graph 查询,与 torch.compile 的 piecewise 上下文分离,避免耦合。
  3. 创建 BreakableCudaGraphRunner (breakable_cuda_graph_runner.py):独立于 PiecewiseCudaGraphRunner,但复用其 replay_prepare 方法。实现 _warmup_capture_allcan_run 等关键方法。捕获时调用 BreakableCUDAGraphCapture 上下文管理器,并在每个 attention 层自动调用 eager_on_graph 包装的函数来打断图。
  4. 修改 radix_attention.py 和 nemotron_h.py:在 RadixAttention.forward 中根据是否启用 BCG 分派到 bcg_unified_attention_with_output(通过 eager_on_graph(True) 包装)。类似地在 NemotronH Mamba2 层添加 breakable_nemotron_mamba2_with_output
  5. 修改 model_runner.py 和 server_args.py:在 init_piecewise_cuda_graphs 中根据 enable_breakable_cuda_graph 参数选择使用 BreakableCudaGraphRunner;新增 --enable-breakable-cuda-graph 命令行参数。
  6. 测试和 CI 配置:将测试文件从 test/registered/cuda_graph/ 迁移到 test/registered/breakable_cuda_graph/,新增集成测试(Qwen3-8B + mgsm_en 精度测试)并注册为 large CI 套件。
文件 模块 状态 重要度
python/sglang/srt/model_executor/breakable_cuda_graph_runner.py 调度器 added 9.36
python/sglang/srt/model_executor/breakable_cuda_graph/breakable_cuda_graph.py 基础库 modified 9.21
python/sglang/srt/model_executor/breakable_cuda_graph/context.py 基础库 added 8.09
test/registered/breakable_cuda_graph/test_breakable_cuda_graph.py 测试 renamed 7.46
python/sglang/srt/models/nemotron_h.py 模型适配 modified 6.68
python/sglang/srt/layers/radix_attention.py 注意力层 modified 6.45
python/sglang/srt/model_executor/model_runner.py 模型执行器 modified 6.38
python/sglang/srt/server_args.py 配置入口 modified 5.31

关键符号

BreakableCudaGraphRunner.__init__ BreakableCudaGraphRunner._warmup BreakableCudaGraphRunner._capture_all BreakableCudaGraphRunner.can_run BreakableCUDAGraph.replay BreakableCUDAGraphCapture.__enter__ BreakableCUDAGraphCapture.__exit__ eager_on_graph enable_breakable_cuda_graph is_in_breakable_cuda_graph bcg_unified_attention_with_output breakable_nemotron_mamba2_with_output

关键源码片段

python/sglang/srt/model_executor/breakable_cuda_graph_runner.py data-contract

新增的主要 Runner,封装了 BCG 的 warmup、capture、replay 全流程,与 PCG 平级。

# breakable_cuda_graph_runner.py — BCG Runner 核心class BreakableCudaGraphRunner:
    """不依赖 torch.compile 的可打断 CUDA Graph Runner。"""
​
    # 复用 PCG 的 replay_prepare 方法,避免复制
    replay_prepare = PiecewiseCudaGraphRunner.replay_prepare
​
    def __init__(self, model_runner: ModelRunner):
        self.model_runner = model_runner
        self.device = model_runner.device
        self.capture_num_tokens = sorted(model_runner.server_args.piecewise_cuda_graph_tokens)
        self.graphs: dict[int, BreakableCUDAGraph] = {}
        self._init_buffers(model_runner) # 初始化需要的静态张量
​
    def _init_buffers(self, model_runner: ModelRunner):
        """创建用于捕获的静态输入缓冲区。"""
        # 每个不同 token 数对应一组 dummy 输入(位置编码、attention mask 等)
        # 这些缓冲区在捕获时用作固定输入
        # 代码略 ...
​
    def _warmup(self):
        """运行模型 warmup(不包括图捕获)。"""
        # 执行几次普通前向,触发热加载
        # 代码略 ...
​
    def _capture_all(self):
        """按预设 token 数量逐一捕获图。"""
        for num_tokens in self.capture_num_tokens:
            forward_batch = self._build_capture_forward_batch(num_tokens)
            graph = BreakableCUDAGraph()
            capture_stream = self.device_module.Stream(self.device)
            with enable_breakable_cuda_graph(): # 启用 BCG 上下文
                with BreakableCUDAGraphCapture(graph, stream=capture_stream):
                    with graph_capture(device=self.device): # 开始 CUDA 图捕获
                        self._run_forward(forward_batch) # 执行一次完整前向
            self.graphs[num_tokens] = graph # 保存
​
    def can_run(self, forward_batch: ForwardBatch) -> bool:
        """判断当前 forward_batch 是否可以用已捕获的图执行。"""
        if forward_batch.input_ids is None:
            return False
        num_tokens = forward_batch.input_ids.shape[0]
        idx = bisect.bisect_left(self.capture_num_tokens, num_tokens)
        if idx == len(self.capture_num_tokens):
            return False
        return True
python/sglang/srt/model_executor/breakable_cuda_graph/breakable_cuda_graph.py data-contract

新一代 Piecewise CUDA Graph 基础设施,定义 BreakableCUDAGraph 类、eager_on_graph 装饰器以及捕获上下文。

# breakable_cuda_graph.py — 核心图段管理import logging
import threading
from contextvars import ContextVar
from typing import Any, Callable, Optionalimport torchtry:
    from cuda.bindings import runtime as rt
except ImportError:
    rt = Nonelogger = logging.getLogger(__name__)# 当前线程活跃的 BreakableCUDAGraphCapture 上下文
_current_capture_var: ContextVar[Optional["BreakableCUDAGraphCapture"]] = ContextVar(
    "current_capture", default=None
)
_current_stream_var: ContextVar[Optional[torch.cuda.Stream]] = ContextVar(
    "current_stream", default=None
)
_forked_streams_var: ContextVar[Optional[set[torch.cuda.Stream]]] = ContextVar(
    "forked_streams", default=None
)class BreakableCUDAGraph:
    """管理多个 CUDA Graph 段的容器。每个段都是一个独立的 `torch.cuda.CUDAGraph`。"""
    def __init__(self):
        self.graphs: list[torch.cuda.CUDAGraph] = []
        self.streams: list[torch.cuda.Stream] = []
        self.pool: Optional[torch.cuda.caching_allocator.CUDAPlacedAllocator] = None
        # ... 其他字段
​
    def _begin_capture_segment(self, stream: torch.cuda.Stream):
        """开始一个新段的捕获。"""
        if self.pool is None:
            self.pool = torch.cuda.caching_allocator.CUDAPlacedAllocator()
        self.pool.beginAllocateToPool()
        self.streams.append(stream)
​
    def _end_capture_segment(self, stream: torch.cuda.Stream):
        """结束当前段,自动 join 任何分叉流。"""
        forked = _forked_streams_var.get()
        if forked:
            for s in forked:
                if _is_capturing(s.cuda_stream):
                    torch.cuda.synchronize(s)
            _forked_streams_var.set(set())
        self.pool.endAllocateToPool()class BreakableCUDAGraphCapture:
    """上下文管理器:进入时开始第一个捕获段,退出时结束最后一个段。"""
    def __init__(self, graph: BreakableCUDAGraph, stream: torch.cuda.Stream):
        self.graph = graph
        self.stream = stream
​
    def __enter__(self):
        self._prev = _current_capture_var.get()
        _current_capture_var.set(self)
        _current_stream_var.set(self.stream)
        self.graph._begin_capture_segment(self.stream)
        return self
​
    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            self.graph._end_capture_segment(self.stream)
        _current_capture_var.set(self._prev)
        _current_stream_var.set(None)def eager_on_graph(break_point: bool = True):
    """装饰器:在捕获模式下,函数调用会结束当前段并开始新段。"""
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            capture = _current_capture_var.get()
            if capture is None:
                return fn(*args, **kwargs)
            capture.graph._end_capture_segment(capture.stream)
            try:
                result = fn(*args, **kwargs)
            finally:
                capture.graph._begin_capture_segment(capture.stream)
            return result
        return wrapper
    return decorator if break_point else (lambda fn: fn)

评论区精华

参数配置设计改进 设计

merrymercy 认为 enable_breakable_cuda_graph 单一参数无法灵活表达 prefill/decode 不同模式,建议引入 --cuda-graph-mode 字典或分离参数。

结论:未解决,当前仅实现单一开关,后续 PR 可能改进。 · unresolved

继承关系与代码复用 设计

merrymercy 和 cctry 均指出不应继承 PiecewiseCudaGraphRunner,最终实现为独立类但复用 replay_prepare 方法。

结论:已采纳:去掉继承,改为组合式复用。 · 已解决

radix_attention 分支复杂化 设计

ispobock 和 merrymercy 认为 forward 中 BCG/PCG 分支使代码复杂,建议抽离为独立函数(如 processing_mla)。

结论:部分解决:BCG 相关逻辑通过 eager_on_graph 和条件分支保持在内,但重构到独立函数的建议未完全采纳。 · 已解决

弱引用张量安全性问题 正确性

frgossen 指出弱引用张量在共享内存池地址重用可能导致数据损坏,cctry 承认风险但表示当前实践避免跨段传递弱引用。

结论:已知风险,当前实现通过不跨段传递弱引用避免,但长期需更安全方案。 · unresolved

本地测试脚本清理 other

ispobock 要求删除本地脚本 run_bcg_comparison.sh。

结论:已删除。 · 已解决

风险与影响

  • 弱引用张量数据安全性:如 frgossen 指出的,弱引用张量可能在池地址复用后指向无效数据,需要确保跨回放段不传播弱引用。(见文件 breakable_cuda_graph.py
  • 性能退化风险:新增的 BCG 路径在 attention 层引入了上下文切换和段捕获开销,虽基准测试与 PCG 持平,但更多模型架构未验证。(radix_attention.py 中的新分支)
  • 与现有 PCG 兼容性:同时维护两套 piecewise 图机制可能引起配置和 bug 修复同步的负担。(model_runner.py 中条件分支)
  • 依赖 'cuda-python':BCG 核心使用 cuda.bindings,若未安装则抛出 ImportError,这增加了一个可选但非无成本的依赖。
  • 启动日志过多:merrymercy 指出每个批次大小都打印日志可能使启动输出冗长。(breakable_cuda_graph_runner.py 中的 tqdm/logging)
  • 开发者:新增的实验性特性使社区可以尝试不使用 torch.compile 的 Piecewise CUDA Graph,降低门槛。但需要维护两人机制,增加了理解和贡献的复杂度。
  • 用户:通过 --enable-breakable-cuda-graph 尝试 BCG,可获得的性能与其环境有关。默认不启用,不影响现有工作流。
  • 系统:在模型加载时多一次捕获过程,消耗额外内存和时间,但回放时无明显开销。测试套件增加了约 130s 的 CI 时间。
  • 团队:需要持续关注两个 piecewise 图路径的同步演进,避免功能偏差。
弱引用张量数据安全性 性能未充分验证 与 PCG 双机制维护负担 依赖 cuda-python 包 启动日志冗长

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论