Prhub

#25037 spec: STANDALONE skips hidden_states end-to-end (Optional schema + None-safe consumers)

原始 PR 作者 hnyls2002 合并时间 2026-05-13 03:27 文件变更 12 提交数 8 评论 1 代码增减 +280 / -79

执行摘要

STANDALONE 投机解码跳过 hidden_states 端到端捕获

STANDALONE 投机解码使用 vanilla LLM 作为 draft 模型,架构上从不读取 spec_info.hidden_states。原有的实现仍然会捕获、分配、拷贝 hidden_states,不仅浪费 GPU 内存,还在 target/draft hidden_size 不一致时暴露形状不匹配 bug。本 PR 通过架构级不变量——STANDALONE 模式下 spec_info.hidden_states 恒为 None——实现端到端跳过,统一解决 #21434 和 #14563 中修复过的类似问题。

值得精读,尤其是 Optional schema 的设计和 None 守卫的分布模式,可作为类似架构变更的参考。重点关注 eagle_info.py 中的 classmethod 返回类型变更和每个 producer 站点的 capture_hidden_mode 三元表达式。

讨论亮点

PR 没有公开的 review 讨论。作者在 body 中详细说明了设计动机和变更范围,并引用了 #21434 和 #14563 作为相关修复。在 merge commit 中处理了与 #25038 重命名冲突(accepted_indices → accept_indices),确保了 None 守卫与新命名共存。

实现拆解

  1. Schema 调整:在 EagleDraftInput 类中将 hidden_states 字段改为 Optional[torch.Tensor]hidden_size_for()dtype_for() 返回 Optional[int]Optional[torch.dtype],对于 STANDALONE 返回 Nonecreate_idle_input() 根据 hidden_size 是否为 None 决定是否创建空张量。
  2. Producer 侧统一 NULL 模式:在 eagle_worker.pymulti_layer_eagle_worker.pyeagle_worker_v2.pyeagle_info_v2.pycuda_graph_runner.py 等文件的 24 个 capture_hidden_mode 赋值点,使用三元表达式 CaptureHiddenMode.NULL if self.speculative_algorithm.is_standalone() else 原值,确保 STANDALONE 模式下不会触发 hidden_states 捕获。
  3. Consumer 侧 None 守卫:在所有读取/切片/拷贝 spec_info.hidden_stateslogits_output.hidden_states 的地方添加 if ... is not None 条件,包括 eagle_info.pyverify() 中两次构建 EagleDraftExtendInput 时的切片、eagle_draft_cuda_graph_runner.py 中 buffer 分配与 replay 拷贝、eagle_worker.pyverify() 中隐藏状态切片、spec_utils.py 中的 shape 检查等。同时 FutureMapspec_need_hidden_states() 对 STANDALONE 返回 False 从而跳过 hidden_states_buf 的存储和加载。
  4. 测试配套:在 test_standalone_speculative_decoding.py 中新增 test_radix_attention 方法,通过 radix-tree 压力测试验证 Optional schema 与 None-safe 组合的正确性。
文件 模块 状态 重要度
python/sglang/srt/speculative/eagle_info.py 投机解码 modified 7.85
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CUDA 图 modified 6.79
python/sglang/srt/speculative/eagle_worker.py 投机解码 modified 6.79
python/sglang/srt/speculative/multi_layer_eagle_worker.py 投机解码 modified 6.61
test/registered/spec/test_standalone_speculative_decoding.py 测试 modified 4.78

关键符号

EagleDraftInput.hidden_size_for EagleDraftInput.dtype_for EagleDraftInput.create_idle_input EagleDraftInput.prepare_for_extend EAGLEWorker.forward_target_extend EAGLEWorker._draft_preprocess_idle EAGLEWorker.draft EAGLEWorker.verify MultiLayerEAGLEWorker.forward_target_extend MultiLayerEAGLEWorker.draft MultiLayerEAGLEWorker.forward_draft_extend_after_decode EAGLEDraftCudaGraphRunner.__init__ EAGLEDraftCudaGraphRunner.replay EAGLEDraftCudaGraphRunner.capture_one_batch_size

关键源码片段

python/sglang/srt/speculative/eagle_info.py core-logic

核心数据类 EagleDraftInput 的 hidden_states 字段改为 Optional,hidden_size_for/dtype_for 返回 Optional,create_idle_input 条件创建,是本次变更的 schema 基础。

# python/sglang/srt/speculative/eagle_info.pyclass EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
    topk_p: torch.Tensor = None
    topk_index: torch.Tensor = None
    # None when the spec algorithm's draft doesn't read hidden_states
    # (e.g., STANDALONE — vanilla LLM draft).
    hidden_states: Optional[torch.Tensor] = None
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
    # ...
​
    @classmethod
    def hidden_size_for(cls, worker) -> Optional[int]:
        """Decode-phase `hidden_states` width. Returns None when the draft
        architecture doesn't consume the field (e.g., STANDALONE)."""
        if worker.speculative_algorithm.is_standalone():
            return None
        return _draft_runner_of(worker).model_config.spec_hidden_size
​
    @classmethod
    def dtype_for(cls, worker) -> Optional[torch.dtype]:
        if worker.speculative_algorithm.is_standalone():
            return None
        return _draft_runner_of(worker).model_config.dtype
​
    @classmethod
    def create_idle_input(
        cls,
        device: torch.device,
        hidden_size: Optional[int],
        dtype: Optional[torch.dtype],
        topk: int,
        capture_hidden_mode: CaptureHiddenMode,
    ):
        return cls(
            bonus_tokens=torch.empty((0,), device=device, dtype=torch.int32),
            hidden_states=(
                torch.empty((0, hidden_size), device=device, dtype=dtype)
                if hidden_size is not None
                else None
            ),
            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
            capture_hidden_mode=capture_hidden_mode,
        )
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py dependency-wiring

CUDA graph runner 的 buffer 分配和 replay 需要适配 Optional hidden_states,是本变更在 CUDA graph 路径上的关键实现。

# python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py@dataclass
class EagleDraftInputBuffers(ForwardInputBuffers):
    input_ids: torch.Tensor
    req_pool_indices: torch.Tensor
    out_cache_loc: torch.Tensor
    positions: torch.Tensor
    mrope_positions: torch.Tensor
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    extend_seq_lens: torch.Tensor
    topk_p: torch.Tensor
    topk_index: torch.Tensor
    hidden_states: Optional[torch.Tensor] # None when STANDALONE
    global_num_tokens_gpu: Optional[torch.Tensor]
    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor]# In __init__:
_hidden_size = EagleDraftInput.hidden_size_for(self.eagle_worker)
hidden_states = (
    torch.zeros(
        (self.max_bs, _hidden_size),
        dtype=EagleDraftInput.dtype_for(self.eagle_worker),
    )
    if _hidden_size is not None
    else None
)# In replay:
if (
    buffers.hidden_states is not None
    and forward_batch.spec_info.hidden_states is not None
):
    buffers.hidden_states[:raw_bs].copy_(
        forward_batch.spec_info.hidden_states
    )

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  1. 遗漏 None 守卫:变更点分散在 24 个 producer 站点和多个 consumer 站点,可能存在遗漏的 hidden_states 访问路径,导致 STANDALONE 模式下的 AttributeError 或类型错误。
  2. 回归影响:对于非 STANDALONE 算法(EAGLE、EAGLE3、FROZEN_KV_MTP、多层 EAGLE),应完全无行为变化,但需要依赖已有测试覆盖。
  3. merge 冲突:merge commit 解决了与 #25038 的命名冲突,但手动解决可能引入错误,需确认 guard 逻辑正确。
  4. 测试覆盖:新增的 test_radix_attention 仅覆盖 radix 场景,缺少对其他 consumer 路径的专项测试。

对用户:STANDALONE 模式用户将显著减少 GPU 内存使用并避免 hidden_size 不匹配导致的崩溃;其他模式用户无影响。对系统:减少了不必要的 hidden_states 传播计算和显存占用,轻微降低 decode 延迟。对团队:架构不变量更清晰,后续添加新算法时需遵循 Optional 模式。

24 处修改点可能遗漏 None 守卫 STANDALONE 模式回归风险 需确保 EAGLE 模式无影响 merge 冲突手动解决风险

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论