执行摘要
- 一句话:将 Frozen-KV MTP 辅助种子步骤融合到捕获的草稿 CUDA 图中
- 推荐动作:该 PR 值得精读,特别是了解如何将 eager forward 步骤融合到现有的 CUDA 图中以减少 kernel launch 开销。设计思路(将第一轮迭代纳入循环)可推广到其他类似场景。
功能与动机
Frozen-KV MTP 在捕获的循环草稿图之前运行一个单 token eager assistant seed forward。由于 seed 和 recurrent 迭代共享相同的 seq_lens - 1 rope 位置(相对于冻结的目标 KV),分离它们只会导致每次 decode 额外的 launch 和同步。将 seed 融入图中可减少开销。
实现拆解
- 修改
_run_assistant_seed_step(frozen_kv_mtp_worker.py):该方法不再执行模型 forward,而是将 seed 输入(bonus_tokens、hidden_states、占位的 topk_p/topk_index 零张量)存放到 batch.spec_info 中,供后续捕获的图使用。原本的模型 forward 逻辑(设置 forward mode、_set_positions、_init_frozen_kv_metadata 等)被移除,改为在 draft_forward 中作为迭代 0 执行。
- 重构
draft_forward(frozen_kv_mtp_worker.py):将 seed 迭代作为 recurrent loop 的第一次迭代(iter 0)集成到捕获的图中。不再需要单独处理 topk==1 的快捷路径(已删除)。
- 扩展
FrozenKVMTPInputBuffers(frozen_kv_mtp_cuda_graph_runner.py):新增 bonus_tokens 缓冲区,用于向图传递 seed token。相应地,在 capture_one_batch_size 和 replay 中复制 bonus_tokens,并移除对 topk_p/topk_index 的复制(现在由 seed iter 自身产生)。
- 添加 profiling 支持(
frozen_kv_mtp_cuda_graph_runner.py):在 replay 方法的图执行周围添加 torch.profiler.record_function span,便于性能分析。
- 更新单元测试 fixture(
speculative_draft_runner.py):_make_dense_frozen_kv_mtp_draft_inputs 和 _make_dense_frozen_kv_mtp_forward_batch 改为提供 bonus_tokens 而非 topk_p/topk_index,与新的输入约定对齐。
关键文件:
python/sglang/srt/speculative/frozen_kv_mtp_worker.py(模块 推测解码;类别 source;类型 core-logic;符号 _run_assistant_seed_step, draft_forward, forward_batch_generation): 核心改动所在,修改了 assistant seed 步骤的实现方式和 draft forward 的迭代逻辑。
python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py(模块 CUDA 图;类别 source;类型 core-logic;符号 FrozenKVMTPInputBuffers, capture_one_batch_size, replay): CUDA 图运行器扩展了输入缓冲区,新增 bonus_tokens 字段,并调整了 replay 逻辑以支持 seed iter 的集成。
python/sglang/test/kits/attention_unittest/runner_modes/speculative_draft_runner.py(模块 测试工具;类别 test;类型 test-coverage;符号 _make_dense_frozen_kv_mtp_draft_inputs, _make_dense_frozen_kv_mtp_forward_batch): 更新了测试 fixture 以匹配新的输入约定,确保单元测试覆盖新的逻辑路径。
关键符号:_run_assistant_seed_step, draft_forward, replay, capture_one_batch_size, _make_dense_frozen_kv_mtp_draft_inputs, _make_dense_frozen_kv_mtp_forward_batch
关键源码片段
python/sglang/srt/speculative/frozen_kv_mtp_worker.py
核心改动所在,修改了 assistant seed 步骤的实现方式和 draft forward 的迭代逻辑。
# python/sglang/srt/speculative/frozen_kv_mtp_worker.py
def _run_assistant_seed_step(
self,
batch: ScheduleBatch,
last_token_ids: torch.Tensor,
last_hidden_states: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor] = None,
mm_input_embeds: Optional[torch.Tensor] = None,
draft_input: Optional[FrozenKVMTPDraftInput] = None,
) -> None:
"""Stash seed inputs on ``batch.spec_info``; the forward runs inside
the captured draft graph (see ``draft_forward``'s seed iter)."""
del seq_lens_cpu, mm_input_embeds, draft_input # unused after folding
if batch.forward_mode.is_idle() or last_token_ids.numel() == 0:
batch.spec_info = FrozenKVMTPDraftInput.create_idle_input(
device=batch.device,
hidden_size=self._recurrent_hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
return
stashed = FrozenKVMTPDraftInput()
stashed.bonus_tokens = last_token_ids.to(torch.int64)
stashed.hidden_states = last_hidden_states
# Real-shaped zeros so inherited `filter_batch`/`merge_batch` can slice
# them between iters; overwritten by the captured seed iter.
bs = last_token_ids.shape[0]
device = last_token_ids.device
stashed.topk_p = torch.zeros(
(bs, self.topk), device=device, dtype=torch.float32
)
stashed.topk_index = torch.zeros(
(bs, self.topk), device=device, dtype=torch.int64
)
stashed.capture_hidden_mode = CaptureHiddenMode.LAST
stashed.num_tokens_per_req = 1
stashed.num_tokens_for_logprob_per_req = 1
batch.spec_info = stashed
python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py
CUDA 图运行器扩展了输入缓冲区,新增 bonus_tokens 字段,并调整了 replay 逻辑以支持 seed iter 的集成。
# python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py
@dataclass
class FrozenKVMTPInputBuffers(ForwardInputBuffers):
req_pool_indices: torch.Tensor
positions: torch.Tensor
mrope_positions: torch.Tensor
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
topk_p: torch.Tensor
topk_index: torch.Tensor
hidden_states: torch.Tensor
# Consumed by the captured seed iter; see `FrozenKVMTPWorker.draft_forward`.
bonus_tokens: torch.Tensor
global_num_tokens_gpu: Optional[torch.Tensor]
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor]
class FrozenKVMTPCudaGraphRunner:
...
def __init__(self, frozen_kv_mtp_worker: FrozenKVMTPWorker):
...
with torch.device(model_runner.device):
...
bonus_tokens = torch.zeros((self.max_bs,), dtype=torch.int64)
...
self.buffers = FrozenKVMTPInputBuffers(
...,
bonus_tokens=bonus_tokens,
...,
)
def replay(self, forward_batch: ForwardBatch):
...
# `topk_p`/`topk_index` are produced by the captured seed iter.
buffers.bonus_tokens[:raw_bs].copy_(forward_batch.spec_info.bonus_tokens)
...
# NVTX span: the graph bypasses `model_runner.forward`'s record_function.
span_name = f"step[DRAFT_LOOP raw_bs={raw_bs} bs={bs} topk={self.topk}]"
if torch.autograd._profiler_enabled():
with torch.profiler.record_function(span_name):
self._replay()
else:
self._replay()
评论区精华
gemini-code-assist[bot] 提出了两项优化建议:
- 清理未用参数(medium 优先级):
_run_assistant_seed_step 的形参 seq_lens_cpu、mm_input_embeds、draft_input 不再使用,建议从签名中移除并更新调用者,以保持 API 整洁。
-
使用 torch.empty 替代 torch.zeros(medium 优先级):占位用的 topk_p 和 topk_index 会被 seed iter 覆写,建议用 torch.empty 避免不必要的 GPU 初始化开销。
以上建议未被采纳,PR 在主要 reviewer 批准后合并。
-
清理 _run_assistant_seed_step 未用参数 (design): 未采纳;PR 作者选择保留参数并使用 del 语句避免警告。
- 使用 torch.empty 替代 torch.zeros 避免初始化开销 (performance): 未采纳;PR 保持 torch.zeros 以确保 filter_batch/merge_batch 在覆写前有明确定义的值。
风险与影响
关联脉络
- PR #26981 Revert "Support spec v2 tree drafting (eagle topk>1) with page_size==1": 都是 speculative-decoding 模块的修改,且涉及 topk 处理,间接相关。
- PR #23273 [NVIDIA] [GDN] Enable FlashInfer MTP verify on SM100+ (Blackwell): 同样是 MTP(Multi-Token Prediction)相关的性能优化,共享类似的技术栈。
参与讨论