Prhub

#23552 Pre-set SWA cache location in CudaGraphRunner

原始 PR 作者 merrymercy 合并时间 2026-04-24 07:51 文件变更 1 提交数 1 评论 3 代码增减 +27 / -0

执行摘要

CUDA 图捕获前预置 SWA 缓存位置,避免回退到逐层翻译路径。

对于混合 SWA 模型(如 Gemma),原有 CUDA 图捕获时 out_cache_loc_swa 未预置,导致 set_kv_buffer 在回放阶段会回退到每层调用 translate_loc_from_full_to_swa 的慢路径。本 PR 通过预分配张量并在图捕获前设置,强制捕获快速 GPU 路径,提升解码性能。

值得精读。PR 展示了如何在 CUDA 图捕获中通过预置缓冲区强制走快速路径的典型手法,对理解 CUDA 图捕获与 KV 缓存交互有参考价值。建议关注后续的类型清理 PR。

讨论亮点

主要讨论:review 评论指出 out_cache_loc_swa 使用了 torch.int64 类型,但 SWA 缓存位置在池内部使用 int32,因此应该为 int32 以节省内存并保持一致性。作者(merrymercy)承认这是好发现,并指出代码库中多处混合使用了 int32/int64,计划在后续 PR 中统一清理。

实现拆解

  1. 新增数据结构字段:在 DecodeInputBuffers 数据类中添加 out_cache_loc_swa: Optional[torch.Tensor] 字段(文件 cuda_graph_runner.py:141),类型为 torch.int64(注意:review 指出更合理应为 int32,作者承诺后续清理)。

  2. 条件分配张量:在 DecodeInputBuffers.create() 方法中增加 is_hybrid_swa: bool = False 参数,当该标志为 True 时分配 out_cache_loc_swa 张量(形状 (max_num_token,),类型 int64),否则为 Nonecuda_graph_runner.py:183-187)。该参数通过 CudaGraphRunner.__init__model_runner.is_hybrid_swa 传入。

  3. 填充数据:在 populate_from_forward_batch() 方法中,当 self.out_cache_loc_swaforward_batch.out_cache_loc_swa 均不为 None 时,将 forward_batch 中的 SWA 缓存位置数据拷贝到缓冲区中,与已有的 GPU 张量拷贝批次合并,利用 _grouped_foreach_copy_ 按 dtype 分组统一拷贝(cuda_graph_runner.py:364-371)。

  4. 图捕获前预置:在 run_once() 函数捕获 CUDA 图之前,检查 self.buffers.out_cache_loc_swa 是否非 None,若是则调用 self.model_runner.token_to_kv_pool.set_swa_loc() 将该缓冲区地址注册到 KV 池中(cuda_graph_runner.py:1148-1156)。这确保图捕获时 set_kv_buffer 内部的 if self.swa_loc is not None 分支走快速 GPU 操作,而非运行时逐层查表翻译。

文件 模块 状态 重要度
python/sglang/srt/model_executor/cuda_graph_runner.py CUDA 图执行器 modified 6.66

关键符号

DecodeInputBuffers.create DecodeInputBuffers.populate_from_forward_batch CudaGraphRunner.__init__ CudaGraphRunner.run_once

关键源码片段

python/sglang/srt/model_executor/cuda_graph_runner.py data-contract

唯一修改的文件,包含所有变更:新增数据类字段、条件分配、数据填充、图捕获前置逻辑。

# cuda_graph_runner.py 关键变更片段@dataclass
class DecodeInputBuffers(ForwardInputBuffers):
    # ...
    out_cache_loc: torch.Tensor
    out_cache_loc_swa: Optional[torch.Tensor] # 新增:SWA 位置缓存,可选
    # ...
​
    @classmethod
    def create(
        cls,
        # ...
        is_hybrid_swa: bool = False, # 新增参数:是否混合 SWA 模型
    ) -> "DecodeInputBuffers":
        with torch.device(device):
            # ...
            out_cache_loc_swa = (
                torch.zeros((max_num_token,), dtype=torch.int64) # 注意:review 建议改为 int32
                if is_hybrid_swa
                else None
            )
            # ...
        return cls(
            # ...
            out_cache_loc_swa=out_cache_loc_swa,
        )
​
    def populate_from_forward_batch(self, *, forward_batch, ...):
        # ... 已有 GPU 拷贝逻辑后
        # SWA cache location (int32, separate from the int64 batch above)
        if (
            self.out_cache_loc_swa is not None
            and forward_batch.out_cache_loc_swa is not None
        ):
            dsts.append(self.out_cache_loc_swa[:raw_num_token])
            srcs.append(forward_batch.out_cache_loc_swa[:raw_num_token])
        # 与已有拷贝合并同 dtype 批次
        _grouped_foreach_copy_(dsts, srcs)
​
    def run_once(self):
        # ... 捕获前设置 SWA 位置
        if self.buffers.out_cache_loc_swa is not None:
            self.model_runner.token_to_kv_pool.set_swa_loc(
                self.buffers.out_cache_loc_swa[:num_tokens]
            )
        # 之后开始图捕获

评论区精华

out_cache_loc_swa 张量类型应为 int32 而非 int64 正确性

gemini-code-assist[bot] 指出 `out_cache_loc_swa` 分配为 `int64`,但池内部使用 `int32`,建议改为 `int32` 以节省内存并保持一致性。

结论:作者承认问题,并计划在后续 PR 中统一清理代码库中 int32/int64 的混合使用。当前 PR 保留 int64 以尽快合并。 · acknowledged

风险与影响

  1. 类型不匹配风险:新引入的 out_cache_loc_swa 使用 int64 类型,但 KV 池内部使用 int32,可能导致隐式类型转换或占用额外显存。作者已认可并计划后续修复。
  2. 缺少测试覆盖:改动涉及 CUDA 图捕获的关键路径,但未包含对应测试或回归验证,存在引入回归的隐患。
  3. 仅影响混合 SWA 模型is_hybrid_swa 标志默认 False,对其他模型无影响,但新字段在非 SWA 场景下为 None,需检查所有调用点是否兼容可选字段。

影响范围:影响使用混合 SWA 模型的 CUDA 图解码路径(如 Gemma),预计可减少每层翻译的 Python 开销,提升解码吞吐。
影响程度:中等。性能优化针对特定模型家族,改动集中且无回归风险(新字段为可选)。
团队影响:为后续 SWA 相关优化铺平道路,且作者已计划对 int32/int64 类型做统一清理。

类型不匹配 (int64 vs int32) 缺少测试覆盖 仅影响混合 SWA 模型

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论