Prhub

#26800 Fix the EAGLE chunked-prefill next-token chain (#26329)

原始 PR 作者 fzyzcjy 合并时间 2026-05-31 09:48 文件变更 4 提交数 1 评论 2 代码增减 +37 / -1

执行摘要

修复 EAGLE chunked prefill draft 链发散 bug

修复EAGLE chunked prefill场景下draft链与目标模型发散的问题。当EAGLE使用chunked prefill时,非最终chunk的tail token应使用下一个prompt token而不是已验证的next token,否则draft chain会偏离目标,导致生成质量下降。

建议精读本次变更,理解chunked prefill与推测解码交互的细节。值得关注的设计决策是:通过在ScheduleBatch中存储chunked_req_next_prompt_token,将chunked状态的查询与draft worker解耦。此外,建议尽快将多层EAGLE worker中的TODO落实为实际修复,并补充端到端测试。

讨论亮点

Code review bot(gemini-code-assist[bot])提出了两个改进建议:

  1. eagle_utils.py中直接使用_eagle_prefill_tail_tokens:建议移除apply_eagle_prefill_input_rotation中的TODO,直接调用新函数以确保所有EAGLE prefill旋转路径的一致性。
  2. 在多层EAGLE worker中应用相同修复:建议在multi_layer_eagle_worker_v2.py中导入并使用_eagle_prefill_tail_tokens,避免循环依赖可考虑局部导入,并传递正确的tail tokens给rotate_input_ids_triton
    目前这些建议未被采纳,仍以TODO形式留待后续。

实现拆解

  1. ScheduleBatch中添加chunked_req_next_prompt_token字段python/sglang/srt/managers/schedule_batch.py):新增_compute_chunked_req_next_prompt_token函数,通过检查chunked_req.fill_ids长度与origin_input_ids长度,确定下一个prompt token。在init_new中调用该函数初始化新字段。
  2. 新增_eagle_prefill_tail_tokens工具函数python/sglang/srt/speculative/eagle_utils.py):该函数接收batchnext_token_ids,对于chunked request,用chunked_req_next_prompt_token替换对应的tail token,否则保持原值。同时保留了原apply_eagle_prefill_input_rotation中的TODO以便后续统一。
  3. 在单层EAGLE draft worker中应用python/sglang/srt/speculative/eagle_worker_v2.py):修改_draft_extend_for_prefill,调用_eagle_prefill_tail_tokens获取正确的tail tokens,替换原来的next_token_ids[i]
  4. 在多层EAGLE draft worker中添加TODOpython/sglang/srt/speculative/multi_layer_eagle_worker_v2.py):由于多层worker使用rotate_input_ids_triton且存在同样的chain divergence问题,添加TODO注释引用PR#26329,待后续修复。
  5. 测试配套:本次变更没有包含直接对应的测试文件,但修复本身为bugfix,涉及核心推测解码路径。
文件 模块 状态 重要度
python/sglang/srt/speculative/eagle_utils.py 推测解码 modified 6.71
python/sglang/srt/managers/schedule_batch.py 调度器 modified 6.65
python/sglang/srt/speculative/eagle_worker_v2.py 推测解码 modified 5.22
python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py 推测解码 modified 4.3

关键符号

_eagle_prefill_tail_tokens _compute_chunked_req_next_prompt_token

关键源码片段

python/sglang/srt/speculative/eagle_utils.py core-logic

新增 `_eagle_prefill_tail_tokens` 核心函数,实现 chunked-aware 的 tail token 替换逻辑;同时保留了原始函数中的 TODO。

def _eagle_prefill_tail_tokens(
    batch: ScheduleBatch, next_token_ids: torch.Tensor
) -> torch.Tensor:
    """Per-seq tail token for EAGLE prefill rotation; uses next prompt token for
    non-final chunks (chunked-prefill chain consistency, see PR #26329)."""
    # 默认使用 verified next token 作为 tail token
    tail_tokens = next_token_ids.to(batch.input_ids.dtype)
    # 如果 batch 有 chunked request 的 next prompt token
    next_prompt_token = batch.chunked_req_next_prompt_token
    if next_prompt_token is not None:
        # 遍历找到 chunked request,替换其 tail token 为 next prompt token
        for i, r in enumerate(batch.reqs):
            if r is batch.chunked_req:
                tail_tokens = tail_tokens.clone()
                tail_tokens[i] = next_prompt_token
                break
    return tail_tokens
python/sglang/srt/managers/schedule_batch.py data-contract

添加 `chunked_req_next_prompt_token` 字段和 `_compute_chunked_req_next_prompt_token` 函数,用于计算并存储 chunked request 的下一个 prompt token。

def _compute_chunked_req_next_prompt_token(
    chunked_req: Optional[Req],
) -> Optional[int]:
    """根据 chunked request 的 fill_ids 计算下一个 prompt token.
    如果 chunked_req 为 None 或已经完成预填充,返回 None."""
    if chunked_req is None:
        return None
    fill_len = len(chunked_req.fill_ids)
    if fill_len >= len(chunked_req.origin_input_ids):
        return None
    return int(chunked_req.origin_input_ids[fill_len])# 在 ScheduleBatch 类中新增字段 :
@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    # ...
    chunked_req_next_prompt_token: Optional[int] = None
    # ...
​
    @classmethod
    def init_new(cls, ...):
        # ...
        batch = cls(
            # ...
            chunked_req=chunked_req,
            chunked_req_next_prompt_token=_compute_chunked_req_next_prompt_token(chunked_req),
            # ...
        )
python/sglang/srt/speculative/eagle_worker_v2.py core-logic

修改 `_draft_extend_for_prefill` 方法,使用 `_eagle_prefill_tail_tokens` 获取正确的 tail tokens,修复 chunked prefill 下的 draft 链发散。

# 导入新函数
from sglang.srt.speculative.eagle_utils import (
    TreeMaskMode,
    _eagle_prefill_tail_tokens, # 新增
    build_tree_kernel_efficient,
    per_step_draft_out_cache_loc,
)# 在 _draft_extend_for_prefill 中使用 :
def _draft_extend_for_prefill(self, batch, target_hidden_states, next_token_ids, ...):
    if not batch.forward_mode.is_idle():
        # Chunked-prefill-aware tail tokens (see PR #26329).
        tail_tokens = _eagle_prefill_tail_tokens(batch, next_token_ids)
        pt = 0
        for i, extend_len in enumerate(batch.extend_lens):
            input_ids = batch.input_ids[pt : pt + extend_len]
            batch.input_ids[pt : pt + extend_len] = torch.cat(
                (input_ids[1:], tail_tokens[i].reshape(1))
            )
            pt += extend_len

评论区精华

在 eagle_utils.py 中直接使用 _eagle_prefill_tail_tokens 设计

gemini-code-assist[bot] 建议直接移除 TODO,在 apply_eagle_prefill_input_rotation 中也使用 _eagle_prefill_tail_tokens 以确保一致性。

结论:未被采纳,保留了 TODO。 · 待处理

在多层 EAGLE worker 中应用相同修复 设计

gemini-code-assist[bot] 建议在 multi_layer_eagle_worker_v2.py 中也导入并使用 _eagle_prefill_tail_tokens,避免循环依赖可考虑局部导入。

结论:未被采纳,仅添加了 TODO。 · 待处理

风险与影响

  1. 回归风险(低):仅修改了chunked prefill下的draft路径,非chunked场景行为不变。但缺少针对chunked prefill + EAGLE的专项测试,可能遗漏边界情况。
  2. 多层EAGLE worker未同步修复multi_layer_eagle_worker_v2.py中仅添加了TODO,实际bug仍然存在,如果用户使用多层EAGLE + chunked prefill,draft链仍然会发散。
  3. 性能影响(极低):新函数_eagle_prefill_tail_tokens引入了额外的循环和克隆操作,但仅在chunked request时执行,开销可忽略。

影响范围:仅限于使用EAGLE推测解码且启用chunked prefill的场景。修复确保draft链与目标模型一致,提升生成质量。影响程度:对于受影响用户,这是一个正确的bug修复;对于不使用chunked prefill或EAGLE的用户,无影响。团队影响:需要在多层EAGLE worker中跟进修复。

缺少测试覆盖 多层 EAGLE worker 未同步修复

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论