执行摘要
- 一句话:DSV4 MHC pre 启动预热,消除冷启动尾部延迟
- 推荐动作:建议阅读
model_runner.py 中的钩子扩展方式和 mhc.py 中代表桶计算逻辑,这是模型特定预热与框架解耦的典型设计。此外,prewarm_mhc_token_counts 方法对显存和时间的权衡(通过 del 及时释放)也值得借鉴。本 PR 无显著风险,可正常合入。
功能与动机
DeepSeek V4 MHC pre 在稀有的 token-count bucket 首次进入服务进程时,会触发 20-40s 的前向 stall(PR body 原文:'In internal DSV4 profiling runs, the long 20-40s forward stalls aligned with first-seen MHC pre split buckets in sparse mixed/decode tail.')。这严重影响尾延迟和吞吐量稳定性。本 PR 将首次加载成本转移到启动预热阶段,使线上流量和基准测试不再为此买单。
实现拆解
-
新增 token-count 代表桶计算(python/sglang/srt/layers/mhc.py):定义 get_mhc_pre_token_count_representatives 函数,基于 _compute_num_split_for_mhc_pre 枚举每个 split bucket,并取每个 bucket 中最大的 token count 作为代表值,避免预热所有 token count。
-
DecoderLayer 层预热方法(python/sglang/srt/models/deepseek_v4.py 内 DeepseekV4DecoderLayer):新增 prewarm_mhc_token_counts 方法,遍历 attn 和 ffn 两条路径,对每个代表 token count 构造虚拟 residual 张量,调用 self.hc_pre,同步 CUDA 后释放。prewarm_mhc_token_count_buckets 方法负责调用代表桶计算后触发预热并返回使用的 token counts。
-
模型级预热 hook(python/sglang/srt/models/deepseek_v4.py 内 DeepseekV4ForCausalLM):新增 kernel_warmup(self, model_runner) 方法,作为模型专属预热钩子。它首先检查 is_hybrid_swa 及两个环境变量开关(SGLANG_OPT_DEEPGEMM_HC_PRENORM、SGLANG_OPT_USE_TILELANG_MHC_PRE),然后从 model_runner.server_args.chunked_prefill_size 推导 max_num_tokens(DP attention 下自动被 DP size 归一),最后委托给第一层 decoder layer 的 prewarm_mhc_token_count_buckets。
-
NextN 模型暴露(python/sglang/srt/models/deepseek_v4_nextn.py):DeepseekV4NextnDecoderLayer 新增 prewarm_mhc_token_count_buckets 委托给内部的 self.decoder,使得 NextN 包装路径也能触发预热。
-
ModelRunner 框架扩展(python/sglang/srt/model_executor/model_runner.py):修改 kernel_warmup 方法,在原有的 FlashInfer autotune 之后,通过 getattr(self.model, 'kernel_warmup', None) 调用模型专属钩子,保持框架通用性。
-
配置与守卫:无新增用户配置项;预热范围直接复用 --chunked-prefill-size,仅当两个环境变量同时启用且为 hybrid SWA 时才执行,不影响其他模型或功能。
关键文件:
python/sglang/srt/models/deepseek_v4.py(模块 模型层;类别 source;类型 core-logic;符号 prewarm_mhc_token_counts, prewarm_mhc_token_count_buckets, kernel_warmup): 核心预热逻辑所在:在 DeepseekV4DecoderLayer 中新增 prewarm_mhc_token_counts 和 prewarm_mhc_token_count_buckets,在 DeepseekV4ForCausalLM 中新增 kernel_warmup 钩子,以及模型级 prewarm_mhc_token_count_buckets 委托。整个预热的实现骨架和条件守卫均在此文件。
python/sglang/srt/layers/mhc.py(模块 MHC 层;类别 source;类型 core-logic;符号 get_mhc_pre_token_count_representatives): 新增 get_mhc_pre_token_count_representatives 函数,它是预热逻辑的“调度”核心:通过遍历 1..max_num_tokens 并调用已有的 _compute_num_split_for_mhc_pre 为每个 split bucket 选取代表 token count,确保每个 bucket 只预热一个代表值,大幅减少预热总次数。
python/sglang/srt/model_executor/model_runner.py(模块 模型执行器;类别 source;类型 data-contract): 扩展了 kernel_warmup 框架方法,在原有的 FlashInfer autotune 之后,通过 getattr(self.model, 'kernel_warmup', None) 触发模型专属预热钩子。这一变化是框架与模型解耦的关键设计点。
python/sglang/srt/models/deepseek_v4_nextn.py(模块 模型层;类别 source;类型 data-contract;符号 prewarm_mhc_token_count_buckets): 为 NextN 包装模型新增 prewarm_mhc_token_count_buckets 委托方法,确保使用 NextN 路径时预热也能正常触发。改动很小但保证了功能完整性。
关键符号:DeepseekV4DecoderLayer.prewarm_mhc_token_counts, DeepseekV4DecoderLayer.prewarm_mhc_token_count_buckets, DeepseekV4ForCausalLM.kernel_warmup, DeepseekV4ForCausalLM.prewarm_mhc_token_count_buckets, get_mhc_pre_token_count_representatives, ModelRunner.kernel_warmup
关键源码片段
python/sglang/srt/models/deepseek_v4.py
核心预热逻辑所在:在 DeepseekV4DecoderLayer 中新增 prewarm_mhc_token_counts 和 prewarm_mhc_token_count_buckets,在 DeepseekV4ForCausalLM 中新增 kernel_warmup 钩子,以及模型级 prewarm_mhc_token_count_buckets 委托。整个预热的实现骨架和条件守卫均在此文件。
# python/sglang/srt/models/deepseek_v4.py
class DeepseekV4DecoderLayer(nn.Module):
# ... 省略现有 __init__ 和 hc_pre ...
def prewarm_mhc_token_counts(
self, token_counts: Tuple[int, ...], device: torch.device
) -> None:
"""对每个代表 token count 分别运行 attn 和 ffn 路径的 hc_pre。"""
paths = (
("attn", self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base, self.input_layernorm),
("ffn", self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base, self.post_attention_layernorm),
)
with torch.inference_mode():
for num_tokens in token_counts:
for path_name, hc_fn, hc_scale, hc_base, norm in paths:
tic = time.perf_counter()
# 创建 dummy 输入,shape 与真实推理保持一致
residual = torch.empty(
(num_tokens, self.hc_mult, self.hidden_size),
dtype=torch.bfloat16,
device=device,
)
y, post, comb, _ = self.hc_pre(residual, hc_fn, hc_scale, hc_base, norm=norm)
# 及时释放,避免 OOM
del residual, y, post, comb
torch.cuda.synchronize()
logger.info(
"DeepSeek V4 MHC prewarm path=%s num_tokens=%s completed in %.3fs",
path_name, num_tokens, time.perf_counter() - tic,
)
def prewarm_mhc_token_count_buckets(
self, max_num_tokens: int, device: torch.device
) -> Tuple[int, ...]:
"""计算代表 token counts 并执行预热,返回实际使用的 token counts。"""
from sglang.srt.layers.mhc import get_mhc_pre_token_count_representatives
token_counts = get_mhc_pre_token_count_representatives(
max_num_tokens, self.hc_mult * self.hidden_size
)
if not token_counts:
return token_counts
logger.info(
"DeepSeek V4 MHC prewarm max_num_tokens=%s representative token counts: %s",
max_num_tokens, token_counts,
)
self.prewarm_mhc_token_counts(token_counts, device)
return token_counts
python/sglang/srt/layers/mhc.py
新增 get_mhc_pre_token_count_representatives 函数,它是预热逻辑的“调度”核心:通过遍历 1..max_num_tokens 并调用已有的 _compute_num_split_for_mhc_pre 为每个 split bucket 选取代表 token count,确保每个 bucket 只预热一个代表值,大幅减少预热总次数。
# python/sglang/srt/layers/mhc.py
def get_mhc_pre_token_count_representatives(
max_num_tokens: int, hc_hidden_size: int
) -> Tuple[int, ...]:
"""Return one token-count representative for each MHC pre split bucket.
MHC pre 内部会根据 token count 和 hidden size 决定 split 数量 (n_splits)。
同一个 n_splits 的 token count 共享相同的 kernel 编译结果,因此只需预热
每个 bucket 内的最大 token count(即 bucket 触发的最后一个 token count),
即可覆盖整个 bucket 的编译开销。
"""
if max_num_tokens <= 0:
return tuple()
representatives_by_split: dict[int, int] = {}
for num_tokens in range(1, max_num_tokens + 1):
n_splits = _compute_num_split_for_mhc_pre(num_tokens, hc_hidden_size)
# 同一个 n_splits 只保留最大的 num_tokens
representatives_by_split[n_splits] = num_tokens
return tuple(
representatives_by_split[n_splits]
for n_splits in sorted(representatives_by_split)
)
python/sglang/srt/model_executor/model_runner.py
扩展了 kernel_warmup 框架方法,在原有的 FlashInfer autotune 之后,通过 getattr(self.model, 'kernel_warmup', None) 触发模型专属预热钩子。这一变化是框架与模型解耦的关键设计点。
# python/sglang/srt/model_executor/model_runner.py
class ModelRunner:
# ...
def kernel_warmup(self):
"""
Warmup and tune kernels before cuda graph capture.
Covers framework-level warmups and optional model-specific warmups.
"""
if self.device != "cuda":
return
if self._should_run_flashinfer_autotune():
self._flashinfer_autotune()
# 模型可以通过定义同名 hook 注册自己的预热逻辑(如 DeepSeek V4 MHC pre)
model_kernel_warmup = getattr(self.model, "kernel_warmup", None)
if model_kernel_warmup is not None:
model_kernel_warmup(self)
评论区精华
审阅者 Fridge003 在评论中指出 prewarm_mhc_token_counts 内部调用了 deep_gemm.tf32_hc_prenorm_gemm 的 JIT 编译,建议后续将其包装到 python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py 中以复用其他 DeepGEMM 核的预热管道,但认为本 PR 较为紧急,可留待未来改进。该 thread 状态为已关闭,未产生分歧。
- 是否将 deep_gemm.tf32_hc_prenorm_gemm 包装到 deep_gemm_wrapper 中 (design): 认同该方向,但 PR 紧急,留待未来优化(deferred)。
风险与影响
关联脉络
- PR #25646 fix deepseek v4 hisparse: 同为 DSV4 MHC 模块的 bugfix,修改
mhc.py 中的压缩器逻辑,与本次预热功能在同一功能线上。
- PR #25860 add git gemm warpper for dispatch_bf16_fp32_backend: 引入 DeepGemm wrapper,与 Fridge003 建议将 MHC pre 内核包装到 wrapper 的思路同向。
参与讨论