执行摘要
- 一句话:修复EAGLE chunked prefill draft链发散bug
- 推荐动作:建议精读本次变更,理解chunked prefill与推测解码交互的细节。值得关注的设计决策是:通过在
ScheduleBatch中存储chunked_req_next_prompt_token,将chunked状态的查询与draft worker解耦。此外,建议尽快将多层EAGLE worker中的TODO落实为实际修复,并补充端到端测试。
功能与动机
修复EAGLE chunked prefill场景下draft链与目标模型发散的问题。当EAGLE使用chunked prefill时,非最终chunk的tail token应使用下一个prompt token而不是已验证的next token,否则draft chain会偏离目标,导致生成质量下降。
实现拆解
- 在
ScheduleBatch中添加chunked_req_next_prompt_token字段(python/sglang/srt/managers/schedule_batch.py):新增_compute_chunked_req_next_prompt_token函数,通过检查chunked_req.fill_ids长度与origin_input_ids长度,确定下一个prompt token。在init_new中调用该函数初始化新字段。
- 新增
_eagle_prefill_tail_tokens工具函数(python/sglang/srt/speculative/eagle_utils.py):该函数接收batch和next_token_ids,对于chunked request,用chunked_req_next_prompt_token替换对应的tail token,否则保持原值。同时保留了原apply_eagle_prefill_input_rotation中的TODO以便后续统一。
- 在单层EAGLE draft worker中应用(
python/sglang/srt/speculative/eagle_worker_v2.py):修改_draft_extend_for_prefill,调用_eagle_prefill_tail_tokens获取正确的tail tokens,替换原来的next_token_ids[i]。
- 在多层EAGLE draft worker中添加TODO(
python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py):由于多层worker使用rotate_input_ids_triton且存在同样的chain divergence问题,添加TODO注释引用PR#26329,待后续修复。
- 测试配套:本次变更没有包含直接对应的测试文件,但修复本身为bugfix,涉及核心推测解码路径。
关键文件:
python/sglang/srt/speculative/eagle_utils.py(模块 推测解码;类别 source;类型 core-logic;符号 _eagle_prefill_tail_tokens): 新增_eagle_prefill_tail_tokens核心函数,实现chunked-aware的tail token替换逻辑;同时保留了原始函数中的TODO。
python/sglang/srt/managers/schedule_batch.py(模块 调度器;类别 source;类型 data-contract;符号 _compute_chunked_req_next_prompt_token): 添加chunked_req_next_prompt_token字段和_compute_chunked_req_next_prompt_token函数,用于计算并存储chunked request的下一个prompt token。
python/sglang/srt/speculative/eagle_worker_v2.py(模块 推测解码;类别 source;类型 core-logic): 修改_draft_extend_for_prefill方法,使用_eagle_prefill_tail_tokens获取正确的tail tokens,修复chunked prefill下的draft链发散。
python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py(模块 推测解码;类别 source;类型 core-logic): 添加TODO注释,指出存在相同的chunked prefill chain divergence问题,待后续修复。
关键符号:_eagle_prefill_tail_tokens, _compute_chunked_req_next_prompt_token
关键源码片段
python/sglang/srt/speculative/eagle_utils.py
新增_eagle_prefill_tail_tokens核心函数,实现chunked-aware的tail token替换逻辑;同时保留了原始函数中的TODO。
def _eagle_prefill_tail_tokens(
batch: ScheduleBatch, next_token_ids: torch.Tensor
) -> torch.Tensor:
"""Per-seq tail token for EAGLE prefill rotation; uses next prompt token for
non-final chunks (chunked-prefill chain consistency, see PR #26329)."""
# 默认使用 verified next token 作为 tail token
tail_tokens = next_token_ids.to(batch.input_ids.dtype)
# 如果 batch 有 chunked request 的 next prompt token
next_prompt_token = batch.chunked_req_next_prompt_token
if next_prompt_token is not None:
# 遍历找到 chunked request,替换其 tail token 为 next prompt token
for i, r in enumerate(batch.reqs):
if r is batch.chunked_req:
tail_tokens = tail_tokens.clone()
tail_tokens[i] = next_prompt_token
break
return tail_tokens
python/sglang/srt/managers/schedule_batch.py
添加chunked_req_next_prompt_token字段和_compute_chunked_req_next_prompt_token函数,用于计算并存储chunked request的下一个prompt token。
def _compute_chunked_req_next_prompt_token(
chunked_req: Optional[Req],
) -> Optional[int]:
"""根据 chunked request 的 fill_ids 计算下一个 prompt token.
如果 chunked_req 为 None 或已经完成预填充,返回 None."""
if chunked_req is None:
return None
fill_len = len(chunked_req.fill_ids)
if fill_len >= len(chunked_req.origin_input_ids):
return None
return int(chunked_req.origin_input_ids[fill_len])
# 在 ScheduleBatch 类中新增字段 :
@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# ...
chunked_req_next_prompt_token: Optional[int] = None
# ...
@classmethod
def init_new(cls, ...):
# ...
batch = cls(
# ...
chunked_req=chunked_req,
chunked_req_next_prompt_token=_compute_chunked_req_next_prompt_token(chunked_req),
# ...
)
python/sglang/srt/speculative/eagle_worker_v2.py
修改_draft_extend_for_prefill方法,使用_eagle_prefill_tail_tokens获取正确的tail tokens,修复chunked prefill下的draft链发散。
# 导入新函数
from sglang.srt.speculative.eagle_utils import (
TreeMaskMode,
_eagle_prefill_tail_tokens, # 新增
build_tree_kernel_efficient,
per_step_draft_out_cache_loc,
)
# 在 _draft_extend_for_prefill 中使用 :
def _draft_extend_for_prefill(self, batch, target_hidden_states, next_token_ids, ...):
if not batch.forward_mode.is_idle():
# Chunked-prefill-aware tail tokens (see PR #26329).
tail_tokens = _eagle_prefill_tail_tokens(batch, next_token_ids)
pt = 0
for i, extend_len in enumerate(batch.extend_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], tail_tokens[i].reshape(1))
)
pt += extend_len
评论区精华
Code review bot(gemini-code-assist[bot])提出了两个改进建议:
- 在
eagle_utils.py中直接使用_eagle_prefill_tail_tokens:建议移除apply_eagle_prefill_input_rotation中的TODO,直接调用新函数以确保所有EAGLE prefill旋转路径的一致性。
- 在多层EAGLE worker中应用相同修复:建议在
multi_layer_eagle_worker_v2.py中导入并使用_eagle_prefill_tail_tokens,避免循环依赖可考虑局部导入,并传递正确的tail tokens给rotate_input_ids_triton。
目前这些建议未被采纳,仍以TODO形式留待后续。
- 在 eagle_utils.py 中直接使用 _eagle_prefill_tail_tokens (design): 未被采纳,保留了 TODO。
- 在多层 EAGLE worker 中应用相同修复 (design): 未被采纳,仅添加了 TODO。
风险与影响
- 风险:
- 回归风险(低):仅修改了chunked prefill下的draft路径,非chunked场景行为不变。但缺少针对chunked prefill + EAGLE的专项测试,可能遗漏边界情况。
- 多层EAGLE worker未同步修复:
multi_layer_eagle_worker_v2.py中仅添加了TODO,实际bug仍然存在,如果用户使用多层EAGLE + chunked prefill,draft链仍然会发散。
- 性能影响(极低):新函数
_eagle_prefill_tail_tokens引入了额外的循环和克隆操作,但仅在chunked request时执行,开销可忽略。
- 影响:影响范围:仅限于使用EAGLE推测解码且启用chunked prefill的场景。修复确保draft链与目标模型一致,提升生成质量。影响程度:对于受影响用户,这是一个正确的bug修复;对于不使用chunked prefill或EAGLE的用户,无影响。团队影响:需要在多层EAGLE worker中跟进修复。
- 风险标记:缺少测试覆盖, 多层EAGLE worker未同步修复
关联脉络
- PR #26329 Original issue/PR for the chain divergence fix: 本PR直接关联issue/PR #26329,多个TODO和注释中明确引用。
参与讨论