Prhub

#24859 [Spec V1] Split draft-extend phase from `EagleDraftInput` into new `EagleDraftExtendInput`

原始 PR 作者 hnyls2002 合并时间 2026-05-10 16:07 文件变更 10 提交数 24 评论 1 代码增减 +327 / -251

执行摘要

拆分推测解码 V1 Draft/Extend 数据结构

PR body 指出:重构前 EagleDraftInput.hidden_states 在同一实例上由 draft 阶段的 [bs, hidden] 切换为 draft-extend 阶段的 [total_accepted, hidden]EagleVerifyOutputnext_draft_input 名不符实(实际包含 extend 数据),且携带 4 个仅用于 verify→extend 衔接的临时字段。这种阶段混淆增加了 attention backend 的特殊判断和 worker 中的维护负担。通过明确分离两种阶段的数据结构,使数据流更清晰且类型安全。

该 PR 值得精读,尤其是 eagle_info.pyfrozen_kv_mtp_info.py 中的数据结构设计。对于从事推测解码开发的工程师,可以学习如何通过类型拆分消除阶段混淆。PR body 中的“Looks confusing but is correct”部分对设计权衡有清晰解释,可作为代码注释的典范。建议在合并前或合并后补充 V2 对齐的 issue 跟踪。

讨论亮点

PR body 中作者主动解释了多处“看似混淆但正确”的细节,可视为设计讨论:

  • filter_batch/merge_batch 虽然 diff 显示被重写,但字节级对比与原来完全一致,仅是位置移动。
  • EagleDraftInput 仍保留 num_accepted_drafts/num_accepted_tokens 等字段,是因为 V2 Overlap Worker 仍复用同一实例跨阶段,这些字段会留在 V2 对齐时清理。
  • bonus_tokens 同时存在于两个 dataclass,但职责不同:kernel 写入 extend-input,worker 拷贝到下一轮 draft-input 供 draft forward 使用。
  • 在所有请求完成分支安装空的 EagleDraftInput 而非留用 EagleVerifyInput,是因为 merge_batch 只定义在 EagleDraftInput 上,空实例的 hidden_states is None 使下次迭代短路。
  • 非 CUDA Graph 路径下的 softmax + fast_topk 内联替换从 capture_for_decode 中提取,语义等价且避免修改即将丢弃的 EagleDraftExtendInput
    这些解释降低了代码审查成本,也体现了作者对隐式契约的理解。

实现拆解

  1. eagle_info.py 中新增 EagleDraftExtendInput dataclass,集中 extend 阶段全部字段(per-accept-token hidden_states、accept counts、input_ids、seq_lens、req_pool_indices 等),并将 prepare_extend_after_decodegenerate_attn_arg_prefillfilter_batchmerge_batch 等操作移入该类。同时精简 EagleDraftInput,只保留 draft 阶段必要字段(topk_p、topk_index、hidden_states[bs, h] 等),V2 专用字段以 Optional 保留并注释。
  2. 修改 EagleVerifyOutput,将 next_draft_input 替换为 draft_extend_input,将 4 个过渡字段(unfinished_accept_tokensseq_lens_for_draft_extendseq_lens_for_draft_extend_cpureq_pool_indices_for_draft_extend)直接归入 EagleDraftExtendInputverify 方法构造并返回 EagleDraftExtendInput 实例。
  3. 调整 worker 控制流:eagle_worker.pymulti_layer_eagle_worker.pyfrozen_kv_mtp_worker.py 中的 forward_batch_generation 在 draft 后安装 verify_inputbatch.spec_info,调用 self.verify(batch)(不再传 spec_info),然后从 verify_output.draft_extend_input 取出 extend 数据安装到 batch.spec_info,调用 forward_draft_extend_after_decode,该方法返回下一轮 EagleDraftInput,由调用者安装。当所有请求完成时安装一个空的 EagleDraftInput(capture_hidden_mode=LAST),确保下一轮 merge_batch 能正确处理(EagleVerifyInputmerge_batch)。
  4. frozen_kv_mtp_info.py 中新增 FrozenKVMTPDraftExtendInput 作为 EagleDraftExtendInput 的标记子类,并重命名转换函数 _to_frozen_kv_mtp_draft_extend_inputfrozen_kv_mtp_worker.pyforward_draft_extend_after_decode 改为从 batch.spec_info 读取 extend 输入,空闲时安装空输入。
  5. forward_batch_info.py 中,_pad_inputs_to_size 改成 getattr 守卫以兼容两个 draft 类的字段差异;spec_info.py 增加 SpecInputType.EAGLE_DRAFT_EXTENDFROZEN_KV_MTP_DRAFT_EXTEND,并确保它们被 is_draft_input() 覆盖。相关 cuda graph runner 更新导入。
文件 模块 状态 重要度
python/sglang/srt/speculative/eagle_info.py 推测解码 modified 8.93
python/sglang/srt/speculative/frozen_kv_mtp_info.py 推测解码 modified 7.96
python/sglang/srt/speculative/eagle_worker.py 推测解码 modified 7.54
python/sglang/srt/speculative/multi_layer_eagle_worker.py 推测解码 modified 7.44
python/sglang/srt/speculative/frozen_kv_mtp_worker.py 推测解码 modified 7.6
python/sglang/srt/model_executor/forward_batch_info.py 前向批处理 modified 5.87

关键符号

EagleDraftExtendInput.__init__ EagleDraftExtendInput.prepare_extend_after_decode EagleDraftExtendInput.generate_attn_arg_prefill EagleDraftExtendInput.filter_batch EagleDraftExtendInput.merge_batch EagleDraftInput.filter_batch EagleDraftInput.merge_batch EagleVerifyInput.verify FrozenKVMTPDraftExtendInput.__post_init__ FrozenKVMTPVerifyInput.verify _to_frozen_kv_mtp_draft_extend_input EagleWorker.forward_draft_extend_after_decode MultiLayerEagleWorker.forward_draft_extend_after_decode FrozenKVMTPWorker.forward_draft_extend_after_decode ForwardBatch._pad_inputs_to_size

关键源码片段

python/sglang/srt/speculative/frozen_kv_mtp_info.py core-logic

对应 Frozen-KV MTP 的数据结构:新增 `FrozenKVMTPDraftExtendInput` 子类,重命名转换函数,同步修改 `FrozenKVMTPVerifyInput.verify` 以返回扩展后的输入。

# frozen_kv_mtp_info.py (head) — 标记子类与转换函数@dataclass
class FrozenKVMTPDraftExtendInput(EagleDraftExtendInput):
    """Draft-extend input for Frozen-KV MTP. Tag-only subclass."""
    def __post_init__(self):
        SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_DRAFT_EXTEND)@dataclass
class FrozenKVMTPVerifyInput(EagleVerifyInput):
    def verify(self, *args, **kwargs) -> EagleVerifyOutput:
        output = super().verify(*args, **kwargs)
        # Convert the extend input from EAGLE type to Frozen-KV MTP type
        output.draft_extend_input = _to_frozen_kv_mtp_draft_extend_input(
            output.draft_extend_input
        )
        return outputdef _to_frozen_kv_mtp_draft_extend_input(
    draft_extend_input: EagleDraftExtendInput,
) -> FrozenKVMTPDraftExtendInput:
    """Field-wise copy guard: skip if already the right type."""
    if isinstance(draft_extend_input, FrozenKVMTPDraftExtendInput):
        return draft_extend_input
    return FrozenKVMTPDraftExtendInput(
        **{
            field.name: getattr(draft_extend_input, field.name)
            for field in fields(EagleDraftExtendInput)
        }
    )

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  • 核心路径变更:V1 推测解码三路 worker 均修改了 verify 和 forward_draft_extend_after_decode 接口,非 CUDA Graph 路径下用内联 softmax+fast_topk 代替 capture_for_decode,虽声明等价但仍需回归验证。
  • V2 兼容性:V2 Overlap Worker 仍使用旧接口(EagleDraftInput 保留 V2 字段),本次 PR 未对齐 V2,后续清理时需注意双向兼容。
  • 缺少测试覆盖:本次 PR 未附带新的单元测试或集成测试,依赖现有 CI(CI 标签 run-ci 已触发),但风险仍存。
  • 数据结构契约:_pad_inputs_to_size 使用 getattr 守卫,若未来在两个 dataclass 上增加同名字段但语义不同,可能导致静默错误。
  • 用户影响:无直接用户可见变化,推测解码行为应与之前一致(PR 声明无行为改变)。
  • 系统影响:清理了大量过渡字段,简化了 verify→extend 路由,降低了 speculative 代码维护复杂度。V2 对齐作为 follow-up,需协调统一方向。
  • 团队影响:开发者阅读 spec 代码更易理解阶段边界;该 PR 可作为重构教科书式的示例,体现数据结构分离消除隐式状态的思路。
核心路径变更 缺少测试覆盖 V2 兼容性未对齐

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论