执行摘要
拆分 top-k 选择函数减少 specdec CPU 开销
根据PR标题和修改内容,主要动机是减少speculative decoding过程中的CPU overhead,提升推理性能。
值得阅读,尤其是拆分torch.compile函数以减少编译开销的模式。开发者可参考此方法优化其他类似分支函数。
审查者Qiaolin-Yu询问为什么该PR与NPU相关,并要求提供torch profiling结果。作者未在评论区回应,但最终获得批准,可能在线下沟通。
根据PR标题和修改内容,主要动机是减少speculative decoding过程中的CPU overhead,提升推理性能。
值得阅读,尤其是拆分torch.compile函数以减少编译开销的模式。开发者可参考此方法优化其他类似分支函数。
审查者Qiaolin-Yu询问为什么该PR与NPU相关,并要求提供torch profiling结果。作者未在评论区回应,但最终获得批准,可能在线下沟通。
| 文件 | 模块 | 状态 | 重要度 |
|---|---|---|---|
python/sglang/srt/speculative/spec_utils.py |
推测解码 | modified | 8.11 |
python/sglang/srt/speculative/eagle_info_v2.py |
推测解码 | modified | 6.05 |
python/sglang/srt/mem_cache/allocator.py |
内存管理 | modified | 5.57 |
python/sglang/srt/hardware_backend/npu/allocator_npu.py |
内存管理 | modified | 5.62 |
python/sglang/srt/mem_cache/swa_memory_pool.py |
内存管理 | modified | 4.49 |
python/sglang/srt/mem_cache/hisparse_memory_pool.py |
内存管理 | modified | 3.95 |
python/sglang/srt/speculative/spec_utils.py
core-logic
核心变更文件,将 select_top_k_tokens 拆分为两个独立函数以减少 torch.compile 开销。
def _select_top_k_tokens_first(
topk_p: torch.Tensor,
topk_index: torch.Tensor,
hidden_states: Optional[torch.Tensor],
topk: int,
):
# 首步选择:直接将 topk_index flatten 作为候选 token ID
# 并 repeat_interleave hidden_states 以匹配 topk 展开
input_ids = topk_index.flatten()
if hidden_states is not None:
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
tree_info = (
topk_p.unsqueeze(1), # (b, 1, topk)
topk_index, # (b, topk)
torch.arange(-1, topk, dtype=torch.long, device=input_ids.device)
.expand(topk_p.shape[0], -1), # expand 避免 repeat 一次分配
)
return input_ids, hidden_states, topk_p, tree_info
@torch.compile(dynamic=True, disable=_is_npu)
def _select_top_k_tokens_later(
i: int,
topk_p: torch.Tensor,
topk_index: torch.Tensor,
hidden_states: torch.Tensor,
scores: torch.Tensor,
topk: int,
):
# 后续步骤:结合历史 scores 和 topk_p 计算 expand_scores,再取 topk
topk_sq = topk * topk
expand_scores = scores.unsqueeze(2) * topk_p.view(-1, topk, topk)
# (b, topk, 1) * (b, topk, topk) -> (b, topk, topk)
topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1
)
topk_index = topk_index.view(-1, topk_sq)
input_ids = torch.gather(topk_index, 1, topk_cs_index).flatten()
if hidden_states.shape[0] > 0:
flat_cs = topk_cs_index.flatten()
batch_offsets = torch.arange(
0, hidden_states.shape[0], step=topk, device=flat_cs.device
)
selected_input_index = flat_cs // topk + batch_offsets.repeat_interleave(topk)
hidden_states = hidden_states[selected_input_index]
tree_info = (
expand_scores, # (b, topk, topk)
topk_index, # (b, topk * topk)
topk_cs_index + (topk_sq * (i - 1) + topk), # (b, topk)
)
return input_ids, hidden_states, topk_cs_p, tree_info
def select_top_k_tokens(
i: int,
topk_p: torch.Tensor,
topk_index: torch.Tensor,
hidden_states: torch.Tensor,
scores: torch.Tensor,
topk: int,
):
# 轻量路由:根据步骤号分派到具体实现
if i == 0:
return _select_top_k_tokens_first(topk_p, topk_index, hidden_states, topk)
return _select_top_k_tokens_later(i, topk_p, topk_index, hidden_states, scores, topk)
审查者 Qiaolin-Yu 问为什么与 NPU 相关,并要求提供 torch profiling 结果。
结论:未在评论区直接回答,但最终 PR 被批准,可能线下说明。 · 已解决
拆分函数和添加可选参数均保持向后兼容,行为一致。但需注意:_select_top_k_tokens_first移除了@torch.compile,对于简单操作性能无影响;若hidden_states为None时逻辑正确。缺少直接测试文件变更,可能回归风险未被覆盖。
直接影响使用speculative decoding的推理请求,CPU开销降低可能提升解码吞吐。对NPU后端同样优化。由于改动集中在核心推理路径,影响面中等,但优化幅度需profiling验证。
当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。
参与讨论