执行摘要
- 一句话:PP 并行 DeepGEMM JIT 预热,启动时间减少约 60%
- 推荐动作:建议精读以下设计决策:
- batch size 的硬件感知推导方法(从 SM 数量和 block_m 推算 n_splits 区间),可推广到其他类似场景。
_dummy_run 的 forward_mode_override 设计,解耦了 forward mode 与 is_generation,提高可测试性。
- 对 PD 分解模式的优雅处理(根据
disaggregation_mode 跳过不必要的 DECODE/EXTEND)。
- 将并行预热逻辑封装在
compile_utils.py,保持 kernel_warmup 的简洁性。
功能与动机
Today the startup warmup /generate request flows through PP stages serially: stage k can only start its DeepGEMM JIT compile after stage k-1 finishes. For large MoE models like DeepSeek-V4 this dominates startup time.
实现拆解
- 环境变量开关:在
environ.py 中新增 SGLANG_PP_PARALLEL_DEEPGEMM_WARMUP(默认关闭),作为功能的守护条件。
- 调用入口:在
model_runner.py 的 kernel_warmup() 方法中,在 FlashInfer 自动调谐之后,检查环境变量、ENABLE_JIT_DEEPGEMM、pp_size > 1、非 speculative 等条件,满足时调用 pp_parallel_deep_gemm_warmup(self)。
- 核心逻辑:在
compile_utils.py 中新增 pp_parallel_deep_gemm_warmup 函数。该函数根据设备的 SM 数量推导 5 个代表 batch size(使得覆盖 n_splits 从 n_sms 到 1 的各个区间),并根据 disaggregation_mode 决定是否执行 DECODE 和 EXTEND 的 dummy forward。对于每个 bs,通过 _dummy_run(forward_mode_override=ForwardMode.DECODE/EXTEND) 触发对应形状的 JIT 编译。
- _dummy_run 扩展:
_dummy_run 新增 forward_mode_override 参数,替代原来仅靠 is_generation 判断 EXTEND/DECODE 的逻辑,使得生成模型也可以通过 override 执行 EXTEND 模式的 dummy。同时修复了 pp_proxy_tensors 切片条件(由 DSA-specific 改为通用判断)、传递 hc_hidden_size 等,保证 DSv4 等模型的兼容性。
- 配套调整:
compile_utils.py 顶部新增 import time、ForwardMode、ceil_align 等导入,避免引入循环依赖。
关键文件:
python/sglang/srt/model_executor/model_runner.py(模块 模型运行器;类别 source;类型 data-contract;符号 kernel_warmup, _dummy_run): 内核预热入口和 dummy forward 的核心修改,新增 forward_mode_override 参数,修改 pp_proxy_tensors 切片逻辑
python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py(模块 DeepGEMM编译;类别 source;类型 core-logic;符号 pp_parallel_deep_gemm_warmup): 新增 pp_parallel_deep_gemm_warmup 函数,实现并行预热核心逻辑
python/sglang/srt/environ.py(模块 环境配置;类别 source;类型 configuration;符号 SGLANG_PP_PARALLEL_DEEPGEMM_WARMUP): 新增 SGLANG_PP_PARALLEL_DEEPGEMM_WARMUP 环境变量开关
关键符号:kernel_warmup, _dummy_run, pp_parallel_deep_gemm_warmup
关键源码片段
python/sglang/srt/model_executor/model_runner.py
内核预热入口和 dummy forward 的核心修改,新增 forward_mode_override 参数,修改 pp_proxy_tensors 切片逻辑
def kernel_warmup(self):
"""Warmup and tune kernels before cuda graph capture."""
if self.device != "cuda":
return
if self._should_run_flashinfer_autotune():
self._flashinfer_autotune()
# PP-parallel DeepGEMM warmup: 仅在启用环境变量、pp_size > 1、
# JIT DeepGEMM 已开启且非 speculative 模式时执行
if (
envs.SGLANG_PP_PARALLEL_DEEPGEMM_WARMUP.get()
and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and self.pp_size > 1
and not self.spec_algorithm.is_speculative()
):
from sglang.srt.layers.deep_gemm_wrapper.compile_utils import (
pp_parallel_deep_gemm_warmup,
)
pp_parallel_deep_gemm_warmup(self)
python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py
新增 pp_parallel_deep_gemm_warmup 函数,实现并行预热核心逻辑
def pp_parallel_deep_gemm_warmup(model_runner) -> None:
"""为每个 PP rank 执行本地 dummy DECODE 和 EXTEND 前向传播,
使得各 rank 的 DeepGEMM JIT 编译可以并行进行,而不是通过
启动时的 /generate 请求串行触发。
通过环境变量 SGLANG_PP_PARALLEL_DEEPGEMM_WARMUP 启用。
"""
# n_splits ≈ n_sms / ceil(bs / block_m),其中 block_m = 64;
# 选取 5 个代表性的 bs 覆盖实际流量中的各个 n_splits 区间
# (最小 decode 形状、中低、两个中等、n_splits = 1 的 ~5K 预填充)。
# 对 bs 执行 ceil_align 确保满足 attn_cp_size 对齐要求
# (DSA 预填充 CP 需要 seq_len % cp_size == 0)。
n_sms = torch.cuda.get_device_properties(model_runner.device).multi_processor_count
block_m = 64
cp = max(model_runner.attn_cp_size, 1)
batch_sizes = sorted(
{
ceil_align(bs, cp)
for bs in (
1, # n_splits = n_sms
2 * block_m, # 中低区间
max(n_sms // 8, 2) * block_m, # 中等区间
max(n_sms // 4, 4) * block_m, # 中等区间
n_sms * block_m, # n_splits = 1
)
}
)
# PD 分解模式下,prefill-only 节点不应执行 decode dummy,
# decode-only 节点不应执行 extend dummy,避免 OOM 或索引器异常。
disagg_mode = model_runner.server_args.disaggregation_mode
run_decode = model_runner.is_generation and disagg_mode != "prefill"
run_extend = disagg_mode != "decode"
logger.info(
"PP-parallel DeepGEMM warmup start "
"(pp_rank=%d, tp_rank=%d, batch_sizes=%s, disagg=%s).",
model_runner.pp_rank,
model_runner.tp_rank,
batch_sizes,
disagg_mode,
)
t0 = time.perf_counter()
with torch.inference_mode():
for bs in batch_sizes:
if run_decode:
model_runner._dummy_run(
batch_size=bs, forward_mode_override=ForwardMode.DECODE
)
if run_extend:
model_runner._dummy_run(
batch_size=bs, forward_mode_override=ForwardMode.EXTEND
)
logger.info(
"PP-parallel DeepGEMM warmup done in %.2fs (pp_rank=%d).",
time.perf_counter() - t0,
model_runner.pp_rank,
)
评论区精华
- Fridge003: 要求将
_pp_parallel_deep_gemm_warmup 函数移至 compile_utils.py,并添加环境变量保护。作者已采纳,将函数移入 compile_utils.py 并添加 SGLANG_PP_PARALLEL_DEEPGEMM_WARMUP(默认关闭)的 env gate。
- Fridge003 质疑 batch size 的选择。whybeyoung 解释:DeepGEMM 的 grouped GEMM 依据 n_splits ≈ n_sms / ceil(bs/block_m) 选择 JIT 内核,通过 5 个 bs 覆盖所有档位。该解释被接受。
- Fridge003 指出
_dummy_run 中原有 not self.is_generation 判断在引入 forward_mode_override 后应改为 capture_forward_mode == EXTEND。BBuf 确认这一修改不仅正确,还修复了在“生成模型 + override=EXTEND”场景下的潜在 bug。
- Fridge003 建议移除
pp_proxy_tensors 切片中对 is_dsa_enable_prefill_cp() 的依赖,使其通用化。作者在第 8 个提交中移除条件判断,改为通用切片逻辑(attn_cp_size > 1 时执行切片)。
- gemini-code-assist[bot] 建议在异常日志中使用
exc_info=True 以记录完整 traceback。该建议在后续提交中因去掉了 try/except 未被采用,但仍有参考价值。
- 函数位置与环境变量保护 (design): 采用:函数移至 compile_utils.py,新增 SGLANG_PP_PARALLEL_DEEPGEMM_WARMUP 环境变量(默认关闭)。
- Batch size 选取依据 (design): 接受解释,batch size 保持基于 SM 数量的推导。
- is_generation 改为 forward_mode_override 的正确性 (correctness): 修改被接受。
- pp_proxy_tensors 切片条件应通用化 (correctness): 在第 8 个提交中移除 DSA-specific 检查,改为仅在
attn_cp_size > 1 时通用切片。
- 异常日志记录完整 traceback (other): 后续提交去掉了 try/except,异常直接上抛,未采用该建议但仍保留参考价值。
风险与影响
- 风险:
- 回归风险:
_dummy_run 的逻辑修改(forward_mode_override、extend buffer 设置、pp_proxy_tensors 切片)可能影响现有 FlashInfer autotune 路径。但 BBuf 和 Fridge003 已审阅确认兼容,且在 CI 中验证通过。
- 兼容性:新增环境变量默认关闭,不影响现有行为。但启动时若设置该变量,仅在 PP>1 且 DeepGEMM JIT 开启时生效,不涉及其他模型。
- 无测试覆盖:PR 没有添加直接测试文件,依赖现有 CI 和手动验证。对于 edge case(如 PD 分解、CP 配置)可能存在遗漏。
- 资源占用:并行编译可能同时启动多个 DeepGEMM JIT 进程,但受限于 SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS 变量,风险可控。
- 影响:
- 用户影响:DeepSeek-V4 等大型 MoE 模型启动时间可缩短约 60%(从 9-12 分钟降至 3.5-4 分钟),显著提升部署体验。其他模型不受影响。
- 系统影响:无运行时性能变化,仅影响启动阶段。并行编译可能短暂增加 CPU 和 GPU 内存使用,但已在 batch size 和 worker 数量上加以限制。
- 团队影响:新增的内部函数和配置项需要维护;
_dummy_run 的接口扩展(增加可选参数)向后续 warmup 扩展开放。
- 风险标记:核心路径变更(_dummy_run), 缺少测试覆盖, 仅对 DeepGEMM 生效, 默认关闭降低风险
关联脉络
参与讨论