执行摘要
- 一句话:重构投机解码中 hidden_size 路由方式
- 推荐动作:建议精读,特别是
EagleDraftExtendInput.hidden_size_for 的实现及 eagle_use_aux_hidden_state 的语义。可作为 speculative decoding 代码可维护性提升的范例。
功能与动机
PR #24926 引入了 EagleDraftInput/EagleDraftExtendInput 的 hidden_size_for 和 dtype_for 类方法,但仍有部分站点(如 idle 输入创建处)保留着内联的 hidden_size 分支。本 PR 延续该清理工作,消除重复的 hidden_size 计算逻辑,确保 EAGLE-3 的 aux hidden state 扩展只在实际启用时生效,避免潜在的大小不匹配。
实现拆解
- multi_layer_eagle_worker.py: 在
__init__ 中新增 eagle_use_aux_hidden_state 属性,从 HF config 中读取 use_aux_hidden_state(默认为 True);在 forward_draft_extend_after_decode 中将内联的 hidden_size 分支替换为 EagleDraftExtendInput.hidden_size_for(self) 和 .dtype_for(self)。
- multi_layer_eagle_draft_extend_cuda_graph_runner.py: 在 CUDA Graph buffer 初始化中,将
hidden_states 张量的大小从 self.model_runner.model_config.hidden_size 改为 EagleDraftExtendInput.hidden_size_for(self.eagle_worker),dtype 同理。
- multi_layer_eagle_worker_v2.py: 在
__init__ 中暴露 self.draft_runner 属性(指向 draft_runner_list[0])以统一 _draft_runner_of 的解析;在 idle 输入创建处替换为 EagleDraftInput.hidden_size_for(self.draft_worker)。
- eagle_worker_v2.py: 同样替换 idle 输入中的内联
spec_hidden_size 和 dtype 为 classmethod 调用。
- 测试配套: 无新增测试文件;PR 作者通过
/rerun-test 手动触发了相关 speculative 和 eagle 测试用例并全部通过。
关键文件:
python/sglang/srt/speculative/multi_layer_eagle_worker.py(模块 投机解码;类别 source;类型 core-logic): 核心变更文件:添加 eagle_use_aux_hidden_state 属性,替换内联 hidden_size 计算为 classmethod 调用。
python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py(模块 投机解码;类别 source;类型 core-logic): CUDA Graph buffer 中 hidden_states 张量大小改为通过 classmethod 推导,确保与 worker 端一致。
python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py(模块 投机解码;类别 source;类型 core-logic): 暴露 draft_runner 属性以便 classmethod 统一访问;替换 idle 输入中的内联 hidden_size 和 dtype 为 classmethod。
python/sglang/srt/speculative/eagle_worker_v2.py(模块 投机解码;类别 source;类型 core-logic): 同样替换 idle 输入中的内联 hidden_size 和 dtype 为 classmethod,实现与 MLE V2 相同的统一。
关键符号:未识别
关键源码片段
python/sglang/srt/speculative/multi_layer_eagle_worker.py
核心变更文件:添加 eagle_use_aux_hidden_state 属性,替换内联 hidden_size 计算为 classmethod 调用。
# 从 hf_config 中提取 eagle_config,判断是否使用 aux hidden state
self.eagle_use_aux_hidden_state = False
if self.speculative_algorithm.is_eagle3():
eagle_config = getattr(
self.model_runner.model_config.hf_config, "eagle_config", {}
)
# 默认为 True,保留与之前无条件 *3 一致的行为
self.eagle_use_aux_hidden_state = eagle_config.get(
"use_aux_hidden_state", True
)
# ... 后续在 idle 输入创建时,不再手动计算 hidden_size
# 改为统一通过类方法路由,且 classmethod 内部已经处理了 *3 的逻辑
draft_extend_input = EagleDraftExtendInput.create_idle_input(
device=self.device,
hidden_size=EagleDraftExtendInput.hidden_size_for(self),
dtype=EagleDraftExtendInput.dtype_for(self),
capture_hidden_mode=CaptureHiddenMode.LAST,
)
python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py
CUDA Graph buffer 中 hidden_states 张量大小改为通过 classmethod 推导,确保与 worker 端一致。
# 之前:
# hidden_states = torch.zeros(
# (self.max_num_token, self.model_runner.model_config.hidden_size),
# dtype=self.model_runner.dtype,
# )
# 之后:统一通过 classmethod 路由,与 worker 端 idle 输入创建保持一致
hidden_states = torch.zeros(
(
self.max_num_token,
EagleDraftExtendInput.hidden_size_for(self.eagle_worker),
),
dtype=EagleDraftExtendInput.dtype_for(self.eagle_worker),
)
python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py
暴露 draft_runner 属性以便 classmethod 统一访问;替换 idle 输入中的内联 hidden_size 和 dtype 为 classmethod。
# 在 __init__ 中,暴露出 draft_runner 属性,
# 供 _draft_runner_of 函数解析 EagleDraftInput 的 hidden_size_for 等类方法。
self.draft_runner: ModelRunner = self.draft_runner_list[0]
# 在 forward_batch_generation 中创建 idle 输入时:
model_worker_batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=EagleDraftInput.hidden_size_for(self.draft_worker),
dtype=EagleDraftInput.dtype_for(self.draft_worker),
topk=self.topk * self.speculative_num_steps,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
评论区精华
无 review 评论或讨论线程。
风险与影响
- 风险:主要风险在于
eagle_use_aux_hidden_state 默认值为 True,与之前无条件 * 3 行为一致,但若某些 EAGLE-3 模型显式设置 use_aux_hidden_state=False 且 model_config.hidden_size * 3 大于实际所需,可能导致 hidden state buffer 过小或过大,引发内存错误或形状不匹配。需确保所有调用 hidden_size_for 的位置都与 CUDA Graph buffer 大小保持一致。
- 影响:影响范围集中在 speculative decoding 模块的 4 个核心文件中,无外部接口或行为变化。对用户透明,但为后续 EAGLE-3 特性(如禁用 aux hidden state)铺平了道路。改动量小,风险可控。
- 风险标记:缺少测试覆盖
关联脉络
- PR #24926 [Spec] Introduce EagleDraft{,Extend}Input classmethods for hidden_size/dtype: 本 PR 是对 PR #24926 的延续,将剩余的 inline hidden_size 计算也统一使用 classmethod 路由。
- PR #25010 [Spec] Remove dead kernel params; fix stale comment in
trtllm_mla: 同一系列 speculative 模块清理工作,虽涉及不同文件,但属相同代码改进方向。
参与讨论