Prhub

#23321 [sgl] reduce specdec cpu overhead

原始 PR 作者 2022tgoel 合并时间 2026-05-05 06:02 文件变更 6 提交数 4 评论 2 代码增减 +93 / -61

执行摘要

拆分 top-k 选择函数减少 specdec CPU 开销

根据PR标题和修改内容,主要动机是减少speculative decoding过程中的CPU overhead,提升推理性能。

值得阅读,尤其是拆分torch.compile函数以减少编译开销的模式。开发者可参考此方法优化其他类似分支函数。

讨论亮点

审查者Qiaolin-Yu询问为什么该PR与NPU相关,并要求提供torch profiling结果。作者未在评论区回应,但最终获得批准,可能在线下沟通。

实现拆解

  1. 拆分spec_utils.py中的select_top_k_tokens:原函数包含if i==0分支,整体被@torch.compile装饰。拆分为两个子函数:_select_top_k_tokens_first(无torch.compile)处理第一个步骤,_select_top_k_tokens_later(保留@torch.compile)处理后续步骤,原函数变为轻量路由。这样避免了torch.compile在动态分支下的重新编译开销。
  2. 优化eagle_info_v2.py的prepare_for_decode:将KV长度计算从动态append改为预分配列表并直接索引赋值,减少Python层面的内存操作和循环开销。
  3. 为alloc_extend系列函数添加可选num_new_pages参数:在allocator.py、allocator_npu.py、swa_memory_pool.py、hisparse_memory_pool.py中,允许调用方传入预先计算好的页面数,避免在函数内重复计算,减少CPU开销。
文件 模块 状态 重要度
python/sglang/srt/speculative/spec_utils.py 推测解码 modified 8.11
python/sglang/srt/speculative/eagle_info_v2.py 推测解码 modified 6.05
python/sglang/srt/mem_cache/allocator.py 内存管理 modified 5.57
python/sglang/srt/hardware_backend/npu/allocator_npu.py 内存管理 modified 5.62
python/sglang/srt/mem_cache/swa_memory_pool.py 内存管理 modified 4.49
python/sglang/srt/mem_cache/hisparse_memory_pool.py 内存管理 modified 3.95

关键符号

select_top_k_tokens _select_top_k_tokens_first _select_top_k_tokens_later prepare_for_decode

关键源码片段

python/sglang/srt/speculative/spec_utils.py core-logic

核心变更文件,将 select_top_k_tokens 拆分为两个独立函数以减少 torch.compile 开销。

def _select_top_k_tokens_first(
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: Optional[torch.Tensor],
    topk: int,
):
    # 首步选择:直接将 topk_index flatten 作为候选 token ID
    # 并 repeat_interleave hidden_states 以匹配 topk 展开
    input_ids = topk_index.flatten()
    if hidden_states is not None:
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
​
    tree_info = (
        topk_p.unsqueeze(1), # (b, 1, topk)
        topk_index, # (b, topk)
        torch.arange(-1, topk, dtype=torch.long, device=input_ids.device)
            .expand(topk_p.shape[0], -1), # expand 避免 repeat 一次分配
    )
    return input_ids, hidden_states, topk_p, tree_info
​
​
@torch.compile(dynamic=True, disable=_is_npu)
def _select_top_k_tokens_later(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    # 后续步骤:结合历史 scores 和 topk_p 计算 expand_scores,再取 topk
    topk_sq = topk * topk
​
    expand_scores = scores.unsqueeze(2) * topk_p.view(-1, topk, topk)
    # (b, topk, 1) * (b, topk, topk) -> (b, topk, topk)
​
    topk_cs_p, topk_cs_index = fast_topk(
        expand_scores.flatten(start_dim=1), topk, dim=-1
    )
​
    topk_index = topk_index.view(-1, topk_sq)
    input_ids = torch.gather(topk_index, 1, topk_cs_index).flatten()
​
    if hidden_states.shape[0] > 0:
        flat_cs = topk_cs_index.flatten()
        batch_offsets = torch.arange(
            0, hidden_states.shape[0], step=topk, device=flat_cs.device
        )
        selected_input_index = flat_cs // topk + batch_offsets.repeat_interleave(topk)
        hidden_states = hidden_states[selected_input_index]
​
    tree_info = (
        expand_scores, # (b, topk, topk)
        topk_index, # (b, topk * topk)
        topk_cs_index + (topk_sq * (i - 1) + topk), # (b, topk)
    )
    return input_ids, hidden_states, topk_cs_p, tree_info
​
​
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    # 轻量路由:根据步骤号分派到具体实现
    if i == 0:
        return _select_top_k_tokens_first(topk_p, topk_index, hidden_states, topk)
    return _select_top_k_tokens_later(i, topk_p, topk_index, hidden_states, scores, topk)

评论区精华

NPU 关联性和 profiling 数据 question

审查者 Qiaolin-Yu 问为什么与 NPU 相关,并要求提供 torch profiling 结果。

结论:未在评论区直接回答,但最终 PR 被批准,可能线下说明。 · 已解决

风险与影响

拆分函数和添加可选参数均保持向后兼容,行为一致。但需注意:_select_top_k_tokens_first移除了@torch.compile,对于简单操作性能无影响;若hidden_states为None时逻辑正确。缺少直接测试文件变更,可能回归风险未被覆盖。

直接影响使用speculative decoding的推理请求,CPU开销降低可能提升解码吞吐。对NPU后端同样优化。由于改动集中在核心推理路径,影响面中等,但优化幅度需profiling验证。

核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论