Prhub

#25774 drop output ids

原始 PR 作者 hnyls2002 合并时间 2026-05-20 08:50 文件变更 6 提交数 11 评论 1 代码增减 +32 / -31

执行摘要

移除 ScheduleBatch.output_ids,统一使用 input_ids 作为跨迭代桥梁

Eliminate the SB.output_ids field — a forward-stream-produced tensor that was leaking onto the schedule stream. input_ids becomes the unified cross-iter bridge: end-of-forward sets it (placeholder for overlap, real tokens for non-overlap); next iter's resolve_future consumes it. (来自 PR body)

本 PR 是一次关键的内部重构,值得 SRT 调度层开发者精读。尤其注意重叠模式下占位符的使用方式,以及后续 PR #25819 将继续清理 prepare_for_extend 中的冗余移位。

讨论亮点

Review 中 gemini-code-assist[bot] 指出在 merge_batch 方法中,对 input_ids 的合并逻辑可能存在问题:如果一个批次的 input_idsNone 而另一个不是,会导致数据丢失或形状不匹配。该评论未得到作者公开回应,PR 即已合并。建议后续跟踪此边界情况的修复。

实现拆解

  1. 移除属性定义:在 ScheduleBatch 类(schedule_batch.py)中删除 output_ids: torch.Tensor 属性定义,将所有内部引用(prepare_for_decodefilter_batchmerge_batch)中对 output_ids 的访问和赋值替换为对 input_ids 的访问。
  2. 重写 run_batch 出口:在 scheduler.pyrun_batch 方法中,不再使用临时变量 future_indices_or_next_token_ids,而是直接在各个分支中将 batch.input_ids 设置为占位符(重叠模式下为 -future_indices.indices,非重叠模式下为 next_token_ids.to(torch.int64))。注意 PP 和 DLLM 返回类型可能不是 Tensor,需用 isinstance 保护。
  3. 调整 prepare_for_decode 入口:在 ScheduleBatch.prepare_for_decode 中删除原来的 self.input_ids = self.output_ids.to(torch.int64); self.output_ids = None 赋值,因为 input_ids 已在上一轮 run_batch 结束时设置完成。penalizer 累积改为直接使用 self.input_ids
  4. 适配解聚解码路径:在 decode_schedule_batch_mixin.pyprocess_prebuilt 中,不再将 self.output_ids 作为列表积累并转为张量,而是使用本地列表 last_tokens 构造 last_tokens_tensor,然后仅在非 spec 场景下赋值给 self.input_ids;spec 场景下 last_tokens_tensor 作为 bonus_tokens 传入 spec_info
  5. 同步 MLX 和 PP 适配:在 MLX 后端的 _finalize 中,将 pending.batch_copy.output_ids = result.next_token_ids 改为 pending.batch_copy.input_ids = result.next_token_ids;在 PP mixin 的 _pp_prep_batch_result 中将 batch.output_ids = pp_outputs["next_token_ids"] 改为 batch.input_ids = pp_outputs["next_token_ids"].to(torch.int64)
  6. 更新 benchmark 脚本:在 bench_one_batch.pydecode 函数中,将 batch.output_ids = input_token_ids 改为 batch.input_ids = input_token_ids.to(torch.int64)
文件 模块 状态 重要度
python/sglang/srt/managers/scheduler.py 调度器 modified 6.65
python/sglang/srt/managers/schedule_batch.py 批处理 modified 6.12
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py 解聚解码 modified 6.25
python/sglang/srt/hardware_backend/mlx/scheduler_mixin.py MLX 后端 modified 4.7
python/sglang/srt/managers/scheduler_pp_mixin.py 流水线并行 modified 4.7
python/sglang/bench_one_batch.py 基准工具 modified 4.09

关键符号

prepare_for_decode process_prebuilt run_batch _build_hisparse_decode_batch filter_batch merge_batch _finalize _pp_prep_batch_result decode

关键源码片段

python/sglang/srt/managers/scheduler.py core-logic

核心调度循环,run_batch 出口处修改 input_ids 设置逻辑,是本次变更的主战场。

# === scheduler.py (extract from run_batch) ===
if self.enable_overlap:
    with self._overlap_forward_isolation(batch):
        bs = len(batch.seq_lens)
        future_indices = self.future_map.alloc_future_indices(bs)
        with self.forward_stream_ctx:
            self.forward_stream.wait_stream(self.schedule_stream)
            self.future_map.resolve_future(batch)
            batch_result = self.model_worker.forward_batch_generation(batch)
            if batch_result.extra_keep_alive_refs:
                self.batch_record_buf[self.batch_record_ct].extend(
                    batch_result.extra_keep_alive_refs
                )
            batch_result.copy_done = self.device_module.Event()
            if batch_result.delay_sample_func is None:
                self.future_map.store_to_map(future_indices, batch_result)
                batch_result.copy_to_cpu(...)
            else:
                batch_result.future_indices = future_indices
​
    # 占位符:下一轮 resolve_future 会从 token_ids_buf 中
    # 通过负索引取真实 token
    batch.input_ids = -future_indices.indices
else:
    batch_result = self.model_worker.forward_batch_generation(batch, **kwargs)
    if isinstance(batch_result.next_token_ids, torch.Tensor):
        # PP 中间 rank 返回 None,DLLM 返回 per-req list
        batch.input_ids = batch_result.next_token_ids.to(torch.int64)
python/sglang/srt/managers/schedule_batch.py core-logic

定义 ScheduleBatch 类,移除 output_ids 属性并调整 prepare_for_decode、filter_batch、merge_batch 中的引用。

# === schedule_batch.py (class ScheduleBatch) ===
class ScheduleBatch:
    # ... 其他属性 ...
    input_ids: torch.Tensor = None # shape: [b], int64, 跨迭代 token 桥
    # output_ids 属性被完全移除
​
    def prepare_for_decode(self):
        # ...
        if self.sampling_info.penalizer_orchestrator.is_required:
            # 之前使用 self.output_ids.to(torch.int64),现直接使用
            # 已经在 run_batch 末尾设置好的 input_ids
            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                self.input_ids
            )
        # 原 self.input_ids = self.output_ids.to(torch.int64); self.output_ids = None 被移除
        # ...
​
    def filter_batch(self, ...):
        # ...
        if self.input_ids is not None:
            self.input_ids = self.input_ids[keep_indices_device]
        # 原 self.output_ids 相关的处理已全部替换
​
    def merge_batch(self, other):
        # ...
        if self.input_ids is not None and other.input_ids is not None:
            self.input_ids = torch.cat([self.input_ids, other.input_ids])
        # 注意:边界条件可能导致 data loss,详见风险分析
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py dependency-wiring

process_prebuilt 方法中不再设置 self.output_ids,改用本地列表并赋值给 self.input_ids,同时避免在 spec 路径下覆盖 prefill 的 input_ids。

# === decode_schedule_batch_mixin.py ===
def process_prebuilt(self, server_args, future_map):
    """Assign the buffered last input id to schedule batch"""
    last_tokens: List[int] = []
    for req in self.reqs:
        last_tokens.append(req.output_ids[-1])
        # grammar 等处理 ...
    last_tokens_tensor = torch.tensor(
        last_tokens, dtype=torch.int64, device=self.device
    )
    if self.spec_algorithm.is_eagle():
        # spec 路径:last_tokens_tensor 作为 bonus_tokens 传入
        spec_info.prepare_for_extend(self) # 保留 prefill prompt 的 input_ids
        spec_info.bonus_tokens = last_tokens_tensor
        # ...
    else:
        # 非 spec:直接作为下一轮 decode 的输入
        self.input_ids = last_tokens_tensor

评论区精华

merge_batch 中 input_ids 合并边界条件 正确性

Review 指出如果 self.input_ids 为 None 而 other.input_ids 不为 None,则 other 的 input_ids 会丢失;如果 self.input_ids 不为 None 而 other.input_ids 为 None,则 batch 会扩展 reqs 但 input_ids 形状不匹配,导致后续运行时错误。建议更安全地处理 None 情况。

结论:PR 已合并,但该问题未公开回复,也未见后续修复。 · unresolved

风险与影响

主要风险:1)output_ids 属性被移除,任何外部代码或未合并分支中直接访问 batch.output_ids 的地方将报错。2)merge_batchinput_ids 的合并逻辑不完善,可能导致 data loss(当 self.input_ids 为 None 时)。3)PP 中间 rank 返回 None,DLLM 返回 list,这些场景下 batch.input_ids 不会被设置,依赖其非 None 的后续操作可能出错。4)重叠模式下使用负索引作为占位符,若 resolve_future 未能正确解析,可能导致错误的 token 输入。

影响范围限于 SRT 调度层内部,对最终用户透明。但所有涉及 ScheduleBatch 构造、拷贝、合并的路径(包括 speculative decoding、disaggregation、PP、MLX)均需确保使用新的 input_ids 语义。影响程度中等,变更机械但需全面审查调用点。性能方面,减少了跨流张量传输和一次 to(torch.int64) 转换,预期在重叠模式下有微小提升。

merge_batch 边界条件待修复 output_ids 属性移除需审计所有调用方 重叠模式占位符依赖 resolve_future 正确实现

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论