执行摘要
- 一句话:CUDA图捕获前预置SWA缓存位置,避免回退到逐层翻译路径。
- 推荐动作:值得精读。PR 展示了如何在 CUDA 图捕获中通过预置缓冲区强制走快速路径的典型手法,对理解 CUDA 图捕获与 KV 缓存交互有参考价值。建议关注后续的类型清理 PR。
功能与动机
对于混合 SWA 模型(如 Gemma),原有 CUDA 图捕获时 out_cache_loc_swa 未预置,导致 set_kv_buffer 在回放阶段会回退到每层调用 translate_loc_from_full_to_swa 的慢路径。本 PR 通过预分配张量并在图捕获前设置,强制捕获快速 GPU 路径,提升解码性能。
实现拆解
-
新增数据结构字段:在 DecodeInputBuffers 数据类中添加 out_cache_loc_swa: Optional[torch.Tensor] 字段(文件 cuda_graph_runner.py:141),类型为 torch.int64(注意:review 指出更合理应为 int32,作者承诺后续清理)。
-
条件分配张量:在 DecodeInputBuffers.create() 方法中增加 is_hybrid_swa: bool = False 参数,当该标志为 True 时分配 out_cache_loc_swa 张量(形状 (max_num_token,),类型 int64),否则为 None(cuda_graph_runner.py:183-187)。该参数通过 CudaGraphRunner.__init__ 从 model_runner.is_hybrid_swa 传入。
-
填充数据:在 populate_from_forward_batch() 方法中,当 self.out_cache_loc_swa 和 forward_batch.out_cache_loc_swa 均不为 None 时,将 forward_batch 中的 SWA 缓存位置数据拷贝到缓冲区中,与已有的 GPU 张量拷贝批次合并,利用 _grouped_foreach_copy_ 按 dtype 分组统一拷贝(cuda_graph_runner.py:364-371)。
-
图捕获前预置:在 run_once() 函数捕获 CUDA 图之前,检查 self.buffers.out_cache_loc_swa 是否非 None,若是则调用 self.model_runner.token_to_kv_pool.set_swa_loc() 将该缓冲区地址注册到 KV 池中(cuda_graph_runner.py:1148-1156)。这确保图捕获时 set_kv_buffer 内部的 if self.swa_loc is not None 分支走快速 GPU 操作,而非运行时逐层查表翻译。
关键文件:
python/sglang/srt/model_executor/cuda_graph_runner.py(模块 CUDA图执行器;类别 source;类型 data-contract;符号 DecodeInputBuffers, DecodeInputBuffers.create, DecodeInputBuffers.populate_from_forward_batch, CudaGraphRunner.init): 唯一修改的文件,包含所有变更:新增数据类字段、条件分配、数据填充、图捕获前置逻辑。
关键符号:DecodeInputBuffers.create, DecodeInputBuffers.populate_from_forward_batch, CudaGraphRunner.init, CudaGraphRunner.run_once
关键源码片段
python/sglang/srt/model_executor/cuda_graph_runner.py
唯一修改的文件,包含所有变更:新增数据类字段、条件分配、数据填充、图捕获前置逻辑。
# cuda_graph_runner.py 关键变更片段
@dataclass
class DecodeInputBuffers(ForwardInputBuffers):
# ...
out_cache_loc: torch.Tensor
out_cache_loc_swa: Optional[torch.Tensor] # 新增:SWA 位置缓存,可选
# ...
@classmethod
def create(
cls,
# ...
is_hybrid_swa: bool = False, # 新增参数:是否混合 SWA 模型
) -> "DecodeInputBuffers":
with torch.device(device):
# ...
out_cache_loc_swa = (
torch.zeros((max_num_token,), dtype=torch.int64) # 注意:review 建议改为 int32
if is_hybrid_swa
else None
)
# ...
return cls(
# ...
out_cache_loc_swa=out_cache_loc_swa,
)
def populate_from_forward_batch(self, *, forward_batch, ...):
# ... 已有 GPU 拷贝逻辑后
# SWA cache location (int32, separate from the int64 batch above)
if (
self.out_cache_loc_swa is not None
and forward_batch.out_cache_loc_swa is not None
):
dsts.append(self.out_cache_loc_swa[:raw_num_token])
srcs.append(forward_batch.out_cache_loc_swa[:raw_num_token])
# 与已有拷贝合并同 dtype 批次
_grouped_foreach_copy_(dsts, srcs)
def run_once(self):
# ... 捕获前设置 SWA 位置
if self.buffers.out_cache_loc_swa is not None:
self.model_runner.token_to_kv_pool.set_swa_loc(
self.buffers.out_cache_loc_swa[:num_tokens]
)
# 之后开始图捕获
评论区精华
主要讨论:review 评论指出 out_cache_loc_swa 使用了 torch.int64 类型,但 SWA 缓存位置在池内部使用 int32,因此应该为 int32 以节省内存并保持一致性。作者(merrymercy)承认这是好发现,并指出代码库中多处混合使用了 int32/int64,计划在后续 PR 中统一清理。
- out_cache_loc_swa 张量类型应为 int32 而非 int64 (correctness): 作者承认问题,并计划在后续 PR 中统一清理代码库中 int32/int64 的混合使用。当前 PR 保留 int64 以尽快合并。
风险与影响
关联脉络
- PR #23545 Fix MoE no_combine: skip router weight in down projection: 同属对 CUDA 图捕获路径的优化,涉及 KV 缓存操作。
- PR #23588 [PD+DP] Allow PrefillDelayer in disaggregated-prefill mode: 同样涉及调度与缓存位置管理,与本 PR 的 SWA 位置预置有间接关联。
参与讨论