Prhub

#25516 refactor: remove ModelWorkerBatch indirection

原始 PR 作者 hnyls2002 合并时间 2026-05-18 09:05 文件变更 21 提交数 22 评论 3 代码增减 +431 / -560

执行摘要

移除 ModelWorkerBatch 中间层,简化批量数据流

消除长期存在的冗余中间层 ModelWorkerBatch,简化数据流和降低维护成本。参考 schedule_batch.py 中原有的 TODO 注释:"ModelWorkerBatch seems a bit redundant and we consider removing it in the future."

值得精读,尤其关注 _overlap_forward_isolation 的上下文管理器设计、一次性覆盖模式以及跨流张量保活策略。可作为架构重构的参考案例。

讨论亮点
  1. 调度器快照的完整性:Review 指出 dataclasses.fields(batch) 只能捕获定义为 dataclass 字段的属性,可能遗漏动态添加的临界张量。作者后续通过 attr_snapshot 使用 dataclasses.fields 明确只覆盖声明字段,未处理动态属性(视为已知限制)。
  2. record_stream 冗余:Review 指出 eagle_worker_v2.pymulti_layer_eagle_worker_v2.py 中手动的 record_stream 调用可能冗余,因为 _record_sb_tensors_on_stream 已覆盖相同字段。建议将调用移动到 rebind 之后以集中处理。作者后续提交将相关逻辑提取到 spec_utils.py,但未删除手动调用。

实现拆解

  1. 移除 ModelWorkerBatch 类:删除 schedule_batch.py 中的 ModelWorkerBatch 数据类和 get_model_worker_batch 方法,清理所有引用。
  2. 迁移一次性覆盖字段:将 seq_lens_cpu_cachecapture_hidden_modereturn_hidden_states_before_norm 三个字段移到 ScheduleBatch,由 ForwardBatch.init_new 消费后自动重置。
  3. 引入快照恢复机制:在 scheduler.py 中添加 _overlap_forward_isolation 上下文管理器,在 Spec V2 重叠前向期间对 ScheduleBatch 进行完整快照,并在完成后恢复,避免 V2 的原地修改泄漏到下一次调度。同时替换 sampling_info 为前向专用副本,防止多次 init_new 重复累积惩罚。
  4. 添加流记录防御:在 spec_utils.py 中新增 record_stream_for_v2_verifyrecord_stream_each,并在 Spec V2 工作器中调用,确保跨流张量(如 input_idsout_cache_loc)在验证阶段不被释放。
  5. 全面适配调用方:修改所有 Spec 工作器(eagle_worker_v2.pymulti_layer_eagle_worker_v2.pyeagle_worker.pymulti_layer_eagle_worker.pydflash_worker.py)和硬件后端(mlx/tp_worker.py)的方法签名,将 ModelWorkerBatch 替换为 ScheduleBatch
  6. 调整工具函数overlap_utils.py 中的 resolve_futuretp_worker.py 中的 forward_batch_embedding 等适配新参数类型。
文件 模块 状态 重要度
python/sglang/srt/managers/scheduler.py 调度器 modified 8.32
python/sglang/srt/managers/schedule_batch.py 批处理 modified 8.24
python/sglang/srt/model_executor/forward_batch_info.py 前向批信息 modified 8.24
python/sglang/srt/speculative/eagle_worker_v2.py 推测解码 modified 8.21
python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py 推测解码 modified 7.86
python/sglang/srt/hardware_backend/mlx/tp_worker.py MLX 硬件 modified 7.55
python/sglang/srt/speculative/spec_utils.py 推测工具 modified 7.55
python/sglang/srt/managers/tp_worker.py 工作进程 modified 7.1
python/sglang/srt/managers/overlap_utils.py 重叠执行 modified 6.58
python/sglang/srt/speculative/dflash_worker.py 推测解码 modified 6.79
python/sglang/srt/speculative/multi_layer_eagle_worker.py 推测解码 modified 6.32
python/sglang/srt/speculative/eagle_worker.py 推测解码 modified 6.31

关键符号

_overlap_forward_isolation record_batch_in_overlap ForwardBatch.init_new record_stream_for_v2_verify record_stream_each draft verify forward_batch_generation

关键源码片段

python/sglang/srt/managers/scheduler.py core-logic

核心调度器,实现 batch 快照恢复和跨流引用的 ring buffer 机制

# python/sglang/srt/managers/scheduler.py@contextmanager
def _overlap_forward_isolation(self, batch: ScheduleBatch):
    """
    使 ScheduleBatch 在一次 overlap forward 中具有事务性:
    1. 快照 V2 字段,以便 forward 后恢复。
    2. 替换 sampling_info 为前向专用副本,避免多次 init_new 重复累积惩罚。
    3. 将 (batch, snapshot) 固定到 batch_record_buf 中 2 个迭代周期,
       确保 GPU 张量在 forward stream 完成前不被释放。
    """
    # 1. 快照:仅对 spec V2 完整保存所有 dataclass 字段
    snapshot_v2_full = batch.is_spec_v2
    sched_snapshot = (
        {f.name: getattr(batch, f.name) for f in dataclasses.fields(batch)}
        if snapshot_v2_full
        else None
    )
    sched_sampling_info = batch.sampling_info
​
    # 2. 替换 sampling_info 为前向副本(orchestrator=None,共享已累计惩罚缓冲区)
    if sched_sampling_info is not None:
        batch.sampling_info = sched_sampling_info.copy_for_forward()
​
    # 3. 将 (batch, snapshot) 固定到 ring buffer,确保张量存活
    # 注意:必须在 sampling_info 替换之后执行,以固定前向副本
    self.record_batch_in_overlap(batch)
​
    try:
        yield
    finally:
        # 恢复快照
        if sched_snapshot:
            for k, v in sched_snapshot.items():
                setattr(batch, k, v)
        # 恢复 sampling_info
        batch.sampling_info = sched_sampling_info
python/sglang/srt/managers/schedule_batch.py core-logic

删除 ModelWorkerBatch 数据结构和转换方法,新增一次性覆盖字段

# python/sglang/srt/managers/schedule_batch.py# 数据流注释更新:
# ScheduleBatch -> ForwardBatch
# ForwardBatch 由 ScheduleBatch 通过 ForwardBatch.init_new 直接构造。@dataclass
class ScheduleBatch(...):
    # ... 原有字段 ...
​
    # 全新:一次性前向覆盖,init_new 消费后自动重置为默认值
    seq_lens_cpu_cache: torch.Tensor = None
    capture_hidden_mode: Optional[CaptureHiddenMode] = None
    return_hidden_states_before_norm: bool = False
​
    # ... 其他字段保持不变 ...
python/sglang/srt/model_executor/forward_batch_info.py data-contract

ForwardBatch.init_new 直接消费 ScheduleBatch,消费一次性覆盖字段

# python/sglang/srt/model_executor/forward_batch_info.py@classmethod
def init_new(
    cls,
    batch: ScheduleBatch, # 现在直接接收 ScheduleBatch
    model_runner: ModelRunner,
):
    # 消费一次性覆盖字段并重置
    capture_hidden_mode = batch.capture_hidden_mode
    batch.capture_hidden_mode = None
    seq_lens_cpu_cache = batch.seq_lens_cpu_cache
    batch.seq_lens_cpu_cache = None
    return_hidden_states_before_norm = batch.return_hidden_states_before_norm
    batch.return_hidden_states_before_norm = False
​
    # 若未覆盖,则从 batch 的 spec_info 等推导默认值
    if capture_hidden_mode is None:
        if batch.return_hidden_states:
            capture_hidden_mode = CaptureHiddenMode.FULL
        elif batch.spec_info is not None:
            capture_hidden_mode = getattr(
                batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
            )
        else:
            capture_hidden_mode = CaptureHiddenMode.NULL
​
    # ... 后续构造 ForwardBatch 的逻辑 ...

评论区精华

调度器快照可能遗漏动态属性 正确性

Review 指出 dataclasses.fields(batch) 仅捕获 dataclass 字段,若存在动态添加的属性则不会包含在快照中,可能导致 GPU 张量提前释放。

结论:作者未回应,但该设计为已知限制,适用于当前所有字段均为 dataclass 字段的假设。 · 已解决

eagle_worker_v2 中 record_stream 冗余 性能

Review 指出手动的 record_stream 调用已被 _record_sb_tensors_on_stream 覆盖,建议将辅助调用移到 rebind 之后以集中清理。

结论:作者后续提交将辅助函数移至 spec_utils,但未删除手动调用,保留了双重记录。 · 已解决

multi_layer_eagle_worker_v2 中 record_stream 冗余 性能

与上一条类似,建议简化 record_stream 调用。

结论:同上,最终保留双重调用以防 rebind 后张量不同。 · 已解决

风险与影响

  1. 快照遗漏风险_overlap_forward_isolation 仅快照 dataclass 字段,若未来因需求在运行时添加非字段属性,可能导致 GPU 张量提前释放。
  2. record_stream 双重记录:手动的 record_stream_record_sb_tensors_on_stream 可能重复记录,带来微小性能开销但无功能影响。
  3. 状态管理复杂度:快照恢复和一次性覆盖机制增加了 ScheduleBatch 的隐式状态转换,可能引入难以调试的临时错误。

用户侧:无直接影响,功能等价。
系统侧:减少一次数据拷贝和转换,轻微提升性能;降低后续开发的心智负担。
团队侧:消除长期 TODO,简化代码维护;新的状态管理模式需要团队成员理解并遵循一次性覆盖契约。

快照可能遗漏动态属性 record_stream 双重记录带来微小开销 新增状态管理增加复杂性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论