Prhub

#25539 [Spec] `FrozenKVMTP` fold assistant seed into captured draft graph

原始 PR 作者 kpham-sgl 合并时间 2026-06-02 13:27 文件变更 3 提交数 8 评论 8 代码增减 +97 / -64

执行摘要

将 Frozen-KV MTP 辅助种子步骤融合到捕获的草稿 CUDA 图中

Frozen-KV MTP 在捕获的循环草稿图之前运行一个单 token eager assistant seed forward。由于 seed 和 recurrent 迭代共享相同的 seq_lens - 1 rope 位置(相对于冻结的目标 KV),分离它们只会导致每次 decode 额外的 launch 和同步。将 seed 融入图中可减少开销。

该 PR 值得精读,特别是了解如何将 eager forward 步骤融合到现有的 CUDA 图中以减少 kernel launch 开销。设计思路(将第一轮迭代纳入循环)可推广到其他类似场景。

讨论亮点

gemini-code-assist[bot] 提出了两项优化建议:

  • 清理未用参数(medium 优先级):_run_assistant_seed_step 的形参 seq_lens_cpumm_input_embedsdraft_input 不再使用,建议从签名中移除并更新调用者,以保持 API 整洁。
  • 使用 torch.empty 替代 torch.zeros(medium 优先级):占位用的 topk_ptopk_index 会被 seed iter 覆写,建议用 torch.empty 避免不必要的 GPU 初始化开销。
    以上建议未被采纳,PR 在主要 reviewer 批准后合并。

实现拆解

  1. 修改 _run_assistant_seed_stepfrozen_kv_mtp_worker.py:该方法不再执行模型 forward,而是将 seed 输入(bonus_tokenshidden_states、占位的 topk_p/topk_index 零张量)存放到 batch.spec_info 中,供后续捕获的图使用。原本的模型 forward 逻辑(设置 forward mode、_set_positions_init_frozen_kv_metadata 等)被移除,改为在 draft_forward 中作为迭代 0 执行。
  2. 重构 draft_forwardfrozen_kv_mtp_worker.py:将 seed 迭代作为 recurrent loop 的第一次迭代(iter 0)集成到捕获的图中。不再需要单独处理 topk==1 的快捷路径(已删除)。
  3. 扩展 FrozenKVMTPInputBuffersfrozen_kv_mtp_cuda_graph_runner.py:新增 bonus_tokens 缓冲区,用于向图传递 seed token。相应地,在 capture_one_batch_sizereplay 中复制 bonus_tokens,并移除对 topk_p/topk_index 的复制(现在由 seed iter 自身产生)。
  4. 添加 profiling 支持(frozen_kv_mtp_cuda_graph_runner.py:在 replay 方法的图执行周围添加 torch.profiler.record_function span,便于性能分析。
  5. 更新单元测试 fixture(speculative_draft_runner.py_make_dense_frozen_kv_mtp_draft_inputs_make_dense_frozen_kv_mtp_forward_batch 改为提供 bonus_tokens 而非 topk_p/topk_index,与新的输入约定对齐。
文件 模块 状态 重要度
python/sglang/srt/speculative/frozen_kv_mtp_worker.py 推测解码 modified 7.34
python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py CUDA 图 modified 6.39
python/sglang/test/kits/attention_unittest/runner_modes/speculative_draft_runner.py 测试工具 modified 5.01

关键符号

_run_assistant_seed_step draft_forward replay capture_one_batch_size _make_dense_frozen_kv_mtp_draft_inputs _make_dense_frozen_kv_mtp_forward_batch

关键源码片段

python/sglang/srt/speculative/frozen_kv_mtp_worker.py core-logic

核心改动所在,修改了 assistant seed 步骤的实现方式和 draft forward 的迭代逻辑。

# python/sglang/srt/speculative/frozen_kv_mtp_worker.pydef _run_assistant_seed_step(
    self,
    batch: ScheduleBatch,
    last_token_ids: torch.Tensor,
    last_hidden_states: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor] = None,
    mm_input_embeds: Optional[torch.Tensor] = None,
    draft_input: Optional[FrozenKVMTPDraftInput] = None,
) -> None:
    """Stash seed inputs on ``batch.spec_info``; the forward runs inside
    the captured draft graph (see ``draft_forward``'s seed iter)."""
    del seq_lens_cpu, mm_input_embeds, draft_input # unused after folding
​
    if batch.forward_mode.is_idle() or last_token_ids.numel() == 0:
        batch.spec_info = FrozenKVMTPDraftInput.create_idle_input(
            device=batch.device,
            hidden_size=self._recurrent_hidden_size,
            dtype=self.model_config.dtype,
            topk=self.topk,
            capture_hidden_mode=CaptureHiddenMode.LAST,
        )
        return
​
    stashed = FrozenKVMTPDraftInput()
    stashed.bonus_tokens = last_token_ids.to(torch.int64)
    stashed.hidden_states = last_hidden_states
    # Real-shaped zeros so inherited `filter_batch`/`merge_batch` can slice
    # them between iters; overwritten by the captured seed iter.
    bs = last_token_ids.shape[0]
    device = last_token_ids.device
    stashed.topk_p = torch.zeros(
        (bs, self.topk), device=device, dtype=torch.float32
    )
    stashed.topk_index = torch.zeros(
        (bs, self.topk), device=device, dtype=torch.int64
    )
    stashed.capture_hidden_mode = CaptureHiddenMode.LAST
    stashed.num_tokens_per_req = 1
    stashed.num_tokens_for_logprob_per_req = 1
    batch.spec_info = stashed
python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py core-logic

CUDA 图运行器扩展了输入缓冲区,新增 bonus_tokens 字段,并调整了 replay 逻辑以支持 seed iter 的集成。

# python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py@dataclass
class FrozenKVMTPInputBuffers(ForwardInputBuffers):
    req_pool_indices: torch.Tensor
    positions: torch.Tensor
    mrope_positions: torch.Tensor
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    topk_p: torch.Tensor
    topk_index: torch.Tensor
    hidden_states: torch.Tensor
    # Consumed by the captured seed iter; see `FrozenKVMTPWorker.draft_forward`.
    bonus_tokens: torch.Tensor
    global_num_tokens_gpu: Optional[torch.Tensor]
    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor]
​
​
class FrozenKVMTPCudaGraphRunner:
    ...
    def __init__(self, frozen_kv_mtp_worker: FrozenKVMTPWorker):
        ...
        with torch.device(model_runner.device):
            ...
            bonus_tokens = torch.zeros((self.max_bs,), dtype=torch.int64)
            ...
        self.buffers = FrozenKVMTPInputBuffers(
            ...,
            bonus_tokens=bonus_tokens,
            ...,
        )
​
    def replay(self, forward_batch: ForwardBatch):
        ...
        # `topk_p`/`topk_index` are produced by the captured seed iter.
        buffers.bonus_tokens[:raw_bs].copy_(forward_batch.spec_info.bonus_tokens)
        ...
        # NVTX span: the graph bypasses `model_runner.forward`'s record_function.
        span_name = f"step[DRAFT_LOOP raw_bs={raw_bs} bs={bs} topk={self.topk}]"
        if torch.autograd._profiler_enabled():
            with torch.profiler.record_function(span_name):
                self._replay()
        else:
            self._replay()

评论区精华

清理 _run_assistant_seed_step 未用参数 设计

gemini-code-assist[bot] 指出 seq_lens_cpu、mm_input_embeds、draft_input 参数不再使用,建议从签名中移除并更新调用者以保持 API 整洁。

结论:未采纳;PR 作者选择保留参数并使用 del 语句避免警告。 · unresolved

使用 torch.empty 替代 torch.zeros 避免初始化开销 性能

gemini-code-assist[bot] 建议将占位的 topk_p 和 topk_index 用 torch.empty 创建,因为它们会在 seed iter 中被覆写。

结论:未采纳;PR 保持 torch.zeros 以确保 filter_batch/merge_batch 在覆写前有明确定义的值。 · unresolved

风险与影响

  1. seed iter 正确性风险:将原本独立的 eager forward 嵌入 CUDA 图可能引入边界错误,尤其是位置设置和 frozen KV metadata 初始化。PR 通过确保 draft_forward 中的 seed iter 设置正确的位置和 forward mode 来缓解。
  2. topk==1 路径回归:commit 历史显示曾恢复 topk==1 路径(restore code path for topk==1),说明该场景需要特别处理,最终版本保留了兼容性。
  3. 未采纳的代码清理:未移除未用参数可能导致后续维护困惑,但不会引发运行时错误。
  4. 测试覆盖:单元测试 fixture 已更新,但未新增针对 topk>1 或不同 batch size 的额外测试。

用户侧:使用 Frozen-KV MTP(如 Gemma4 31B)的推理可获得性能提升(PR body 显示延迟从约 6.25ms 降至 5.45ms)。功能性无变化,精度保持(GSM8K 评分通过阈值)。
系统侧:CUDA 图捕获和回放逻辑略作调整,但兼容现有调度和 KV 缓存路径。
团队侧:减少了代码分支(移除 topk==1 快捷路径),降低了维护成本。

核心路径变更 性能优化需精度验证 未清理未用参数

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论