Prhub

#26797 [core] Compute token_type_ids in ForwardBatch.init_new

原始 PR 作者 hnyls2002 合并时间 2026-05-31 15:54 文件变更 2 提交数 4 评论 2 代码增减 +40 / -42

执行摘要

将 token_type_ids 计算挪入 ForwardBatch

PR body 明确说明:"Move the cross-encoder token_type_ids device-tensor build out of ScheduleBatch.prepare_for_extend into ForwardBatch.init_new, so the H2D runs on the forward stream and SB no longer carries a forward-only field." 这是一个明确的职责分离和流管理优化。

值得精读。该 PR 展示了如何通过重构保持 ScheduleBatch 的职责纯洁性(只做调度编排),将前向相关的设备张量构建下沉到 ForwardBatch,是流管理和职责分离的良好实践。但需注意 review 中提出的性能建议尚未解决,可在后续跟进。

讨论亮点

gemini-code-assist[bot] 在 review 中提出性能优化建议:原来使用 sum(list_of_lists, []) 展平列表是 O(N^2) 的反模式,建议改为嵌套列表推导式(O(N) 复杂度)。该建议未被作者采纳或回复,PR 就已合并。

实现拆解

  1. forward_batch_info.pyinit_new 中移除了 token_type_ids=batch.token_type_ids 的传递,改为从 batch.reqs 直接收集并构建设备张量。
  2. 将原有的 _maybe_init_prefill_only 方法重命名为 _maybe_init_non_generation_fields,扩大了其职责:除了原有的 dimensionsreturn_pooled_hidden_statesmulti_item_delimiter_indices 外,新增了对 token_type_ids 的处理。
  3. schedule_batch.pyScheduleBatch 类定义中移除了 token_type_ids 字段声明,并在 prepare_for_extend 中删除了对应的收集、构造张量以及赋值操作。清理后的代码中,prepare_for_extend 不再持有或构建该张量。
  4. 为支持 token_type_ids 的 pinned memory 分配,在 forward_batch_info.py 的 import 中新增了 is_pin_memory_available 的导入。
文件 模块 状态 重要度
python/sglang/srt/model_executor/forward_batch_info.py 前向批处理 modified 8.16
python/sglang/srt/managers/schedule_batch.py 调度批处理 modified 5.97

关键符号

_maybe_init_prefill_only _maybe_init_non_generation_fields prepare_for_extend init_new

关键源码片段

python/sglang/srt/model_executor/forward_batch_info.py core-logic

核心变更文件,重命名并扩展 `_maybe_init_prefill_only` 为 `_maybe_init_non_generation_fields`,在其中新增 `token_type_ids` 设备张量构建逻辑。

# python/sglang/srt/model_executor/forward_batch_info.pydef _maybe_init_non_generation_fields(self, batch: ScheduleBatch):
    """Derive non-generation (max_new_tokens==0) forward fields from reqs.    token_type_ids gates on presence, not is_prefill_only: a missing
    tensor makes bert/roberta silently fall back to zeros.
    """
    if self.is_prefill_only:
        # 原有逻辑:dimensions (Matryoshka), return_pooled_hidden_states, multi_item_delimiter_indices
        if batch.model_config.is_matryoshka and any(
            r.dimensions is not None for r in batch.reqs
        ):
            self.dimensions = [
                r.dimensions if r.dimensions else batch.model_config.hidden_size
                for r in batch.reqs
            ]
        self.return_pooled_hidden_states = any(
            r.return_pooled_hidden_states for r in batch.reqs
        )
        if get_global_server_args().enable_mis and any(
            r.multi_item_delimiter_indices is not None for r in batch.reqs
        ):
            assert all(
                r.multi_item_delimiter_indices is not None for r in batch.reqs
            ), "MIS batch must have delimiter indices on every request"
            self.multi_item_delimiter_indices = [
                torch.tensor(r.multi_item_delimiter_indices, dtype=torch.int64)
                for r in batch.reqs
            ]
​
    # 新增:从每个 req 收集 token_type_ids,若存在则构建设备张量
    # 注意:当前实现使用 sum(list_of_lists, []) 展平,复杂度 O(N^2)
    token_type_ids = [
        r.token_type_ids for r in batch.reqs if r.token_type_ids is not None
    ]
    if token_type_ids:
        self.token_type_ids = torch.tensor(
            sum(token_type_ids, []),
            dtype=torch.int64,
            pin_memory=is_pin_memory_available(self.device),
        ).to(self.device, non_blocking=True)

评论区精华

使用 O(N^2) sum 展平列表的性能问题 性能

gemini-code-assist[bot] 指出使用 `sum(list_of_lists, [])` 展平 token_type_ids 是 O(N^2) 反模式,建议改用嵌套列表推导式。

结论:建议未被采用或回复,PR 直接合并。 · unresolved

风险与影响

  1. 性能风险:当前实现仍使用 sum(list_of_lists, []) 展平 token_type_ids,对于大批量请求可能存在 O(N^2) 的性能问题,但通常 token_type_ids 长度较小,实际影响有限。
  2. 回归风险:token_type_ids 的构建逻辑从 prepare_for_extend 移到了 init_new,且只有 is_prefill_only 时才会执行,需要确认所有使用 token_type_ids 的场景(如上文提到的 bert/roberta)均满足此条件。
  3. 兼容性风险:ScheduleBatch 不再暴露 token_type_ids 字段,任何直接访问该字段的外部调用都会出错。

影响范围较小,仅涉及两个核心文件。主要影响使用 cross-encoder(如 bert/roberta)且启用了 is_prefill_only 模式的请求路径。对于普通 decode 路径无影响。

潜在性能反模式 缺失测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论