执行摘要
- 一句话:支持 bs>1 的可中断 CUDA 图执行
- 推荐动作:值得精读。该 PR 体现了深刻的技术洞察:通过重新划定 CUDA 图捕获边界,使图与 batch size 解耦,是使图化预填充支持多请求的关键设计。代码改动简洁(仅 1 文件 +84/-57),但思路值得借鉴。建议关注后续改进 layer_model 解析的多模型兼容性和测试覆盖。
功能与动机
原有的 BreakableCudaGraphRunner 捕获整个模型 forward (包括 logits_processor 和 pooler),这些核的形状依赖于 batch size,导致图只能在 bs=1 时有效。为了支持多请求预填充时依然可以在 CUDA 图上执行,需要将捕获边界缩小至不依赖 batch size 的内部 transformer 层。
实现拆解
实现分为四步:
-
解析 layer_model (init):在初始化时通过 language_model = getattr(model_runner.model, 'language_model', model_runner.model) 和后续的 language_model.model 解析出内部 transformer 栈模块(与 PiecewiseCudaGraphRunner 的 patch_model 边界一致),并赋值给 self.layer_model。此方法依赖模型属性名称,存在鲁棒性问题(见 review)。
-
修改 _run_forward:将 model_runner.model.forward 调用替换为 self.layer_model.forward,只执行 transformer 栈。添加 @torch.no_grad 装饰器以匹配外部 ForCausalLM.forward 的无梯度模式。
-
修改 _build_capture_forward_batch:构建占位的 forward_batch 时使用 bs=1(作为 attention/mamba 分段元数据形状的占位),实际的 bs>1 元数据由 replay_prepare 在 replay 时注入。
-
清理与简化:移除原先为 bs 参数静态分配的 static_seq_lens 等一批形状为 (max_bs,) 的张量,以及验证计数器 replay/can_run_reject 和相关日志。
无其他文件变更,无测试文件配套(但原机制已有覆盖)。
关键文件:
python/sglang/srt/model_executor/breakable_cuda_graph_runner.py(模块 执行引擎;类别 source;类型 core-logic;符号 init, _run_forward, _build_capture_forward_batch, _init_buffers): 核心变更文件,实现将 CUDA 图捕获边界缩小到 layer_model,支持 bs>1。
关键符号:init, _run_forward, _build_capture_forward_batch, _init_buffers
关键源码片段
python/sglang/srt/model_executor/breakable_cuda_graph_runner.py
核心变更文件,实现将 CUDA 图捕获边界缩小到 layer_model,支持 bs>1。
# __init__ 中新增的 layer_model 解析(替换原有的静态 bs 张量分配)
def __init__(self, model_runner: ModelRunner):
# ... 前面代码不变
# 解析内部 transformer 栈模块,边界与 PCG patch_model 一致
language_model = getattr(
model_runner.model, "language_model", model_runner.model
)
self.layer_model = (
language_model.model
if hasattr(language_model, "model")
and hasattr(language_model.model, "layers")
else language_model
)
# 注意:此方式依赖模型属性命名,不够鲁棒,review 已指出
# Memory pool ( 不变 )
if get_global_graph_memory_pool() is None:
set_global_graph_memory_pool(self.device_module.graph_pool_handle())
set_graph_pool_id(get_global_graph_memory_pool())
# Warmup / capture(不变)
self._warmup()
self.device_module.synchronize()
self.model_runner.tp_group.barrier()
self._capture_all()
self.raw_num_tokens = 0
# _run_forward 使用 layer_model.forward,并添加 @torch.no_grad
@torch.no_grad() # 新增,匹配外部的 torch.no_grad 装饰
def _run_forward(self, forward_batch, num_tokens):
"""只执行内部 transformer 栈前向,避免 bs 相关形状固化。
"""
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(None, num_tokens, forward_batch.dp_padding_mode.is_max_len())
set_is_extend_in_batch(False)
with set_forward_context(
forward_batch,
self.attention_layers,
self.quant_config,
self.moe_layers,
self.moe_fusions,
):
output = self.layer_model.forward( # 原为 model_runner.model.forward
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
)
return output
评论区精华
Reviewer merrymercy 指出 layer_model 的解析方式依赖于字符串名称匹配(model_runner.model.language_model.model),非常脆弱,建议至少应在匹配失败时发出警告。PR 作者未公开回复此评论,未添加警告,但 PR 最终被合并。该设计决策属于技术债务,需后续改进。
- layer_model 解析依赖字符串名称匹配不鲁棒 (design): PR 作者未公开回应,未添加 warning,但 PR 最终被合并,该设计被接受为初期实现。
风险与影响
关联脉络
参与讨论