Prhub

#22544 [Score API] Add Multi-Item Scoring with pre-computed delimiter indices

原始 PR 作者 fortunecookiee 合并时间 2026-04-21 13:50 文件变更 17 提交数 11 评论 40 代码增减 +1072 / -432

执行摘要

引入预计算分隔符索引的 Multi-Item Scoring 优化,消除 GPU 扫描提升性能。

根据PR body描述,MIS原本需要调用者选择分隔符token ID(--multi-item-scoring-delimiter),并在运行时扫描input_ids以定位分隔符,导致每次前向传递在注意力、logits和池化阶段重复GPU操作,且.item().nonzero()调用引入隐式CUDA同步。新方案旨在通过预计算分隔符索引消除这些开销,简化用户配置并提升性能。

建议技术管理者和工程师精读此PR,重点关注:

  1. 预计算索引的数据流设计如何消除GPU扫描,可作为高性能计算中减少同步的案例。
  2. pool_at_delimiter_positions函数实现展示了如何批量处理变长索引并优化设备传输。
  3. 讨论中的权衡决策,如边界处理、性能优化取舍,对类似优化有参考价值。
讨论亮点

review讨论聚焦于正确性、性能和设计权衡:

  • 边界情况处理:gemini-code-assist[bot]指出分隔符在索引0时(空查询)indices_tensor - 1可能导致负索引,开发者回应已在pool_at_delimiter_positions中添加注释说明该情况无害,因第一个分隔符条目总被_process_multi_item_scoring_results丢弃。
  • 性能优化:kpham-sgl建议批量索引传输至GPU以避免隐式阻塞,开发者更新代码将CPU张量堆叠后单次传输,但权衡后未使用pin_memory以避免小张量性能回归。
  • 设计兼容性:kpham-sgl询问与PR #22427的关系及MIS与非MIS请求混合批处理,开发者确认已rebase并修复,且MIS启用后批次内所有请求必须携带分隔符索引。
  • CUDA图支持:kpham-sgl询问未来计划,开发者表示将通过分段CUDA图提供端到端支持。

实现拆解

  1. 服务器参数与入口配置:在python/sglang/srt/server_args.py中新增--enable-mis布尔标志,硬编码MIS_DELIMITER_TOKEN_ID = 9999作为占位符;_handle_multi_item_scoring()函数在初始化后自动禁用CUDA图、基数缓存和分块预填充,并断言使用flashinfer注意力后端,确保配置一致性。
  2. 分隔符索引预计算与传递:在python/sglang/srt/managers/tokenizer_manager_score_mixin.py中,修改_build_multi_item_token_sequence()函数返回分隔符索引列表;tokenization阶段根据项目长度计算索引,作为CPU张量通过io_structschedule_batch传递至forward_batch_info,确保批量内请求一致性。
  3. 核心消费者更新
    • python/sglang/srt/layers/attention/flashinfer_backend.py_process_multi_item_scoring()使用预计算索引和torch.searchsorted替换GPU扫描。
    • python/sglang/srt/layers/logits_processor.pycompute_logprobs_for_multi_item_scoring()基于索引在CPU构建偏移并批量传输至设备。
    • python/sglang/srt/layers/pooler.py:新增pool_at_delimiter_positions()辅助函数,利用预计算索引提取分隔符前隐藏状态,score_and_pool()中优先切片再应用评分头以提升效率。
  4. 调度与输出处理调整python/sglang/srt/managers/scheduler_output_processor_mixin.py修改_is_multi_item_scoring()和logprob计数逻辑,使用预计算索引而非扫描origin_input_ids
  5. 测试配套更新:新增test/registered/prefill_only/test_multi_item_scoring.py包含6个测试类验证服务器参数、生成和分类MIS、一致性和并发性;更新test_score_api.pytest_score_engine.py等现有测试以使用--enable-mis标志;单元测试test/registered/unit/layers/test_pooler_score_and_pool.py重构以验证pool_at_delimiter_positions和MIS路径。
文件 模块 状态 重要度
python/sglang/srt/layers/pooler.py 池化层 modified 7.85
python/sglang/srt/managers/tokenizer_manager_score_mixin.py 分词管理 modified 7.96
test/registered/prefill_only/test_multi_item_scoring.py 评分测试 added 7.76
python/sglang/srt/server_args.py 服务器参数 modified 6.87
python/sglang/srt/layers/logits_processor.py Logits 处理 modified 6.97

关键符号

pool_at_delimiter_positions _build_multi_item_token_sequence compute_logprobs_for_multi_item_scoring _process_multi_item_scoring score_and_pool

关键源码片段

python/sglang/srt/layers/pooler.py core-logic

核心逻辑变更,新增 `pool_at_delimiter_positions` 函数并重构 `score_and_pool` 以使用预计算分隔符索引,直接影响 MIS 性能优化。

def pool_at_delimiter_positions(
    data: torch.Tensor,
    forward_batch: ForwardBatch,
    device: torch.device,
) -> List[torch.Tensor]:
    """Pool a tensor at the position before each MIS delimiter for every request.    Uses pre-computed delimiter indices from ForwardBatch (CPU tensors),
    moves to GPU with non_blocking=True to avoid CUDA syncs.
    """
    all_index_tensors: List[torch.Tensor] = []
    delim_counts: List[int] = []
    offset = 0
    for req_idx, req_seq_len in enumerate(forward_batch.extend_seq_lens_cpu):
        indices_tensor = forward_batch.multi_item_delimiter_indices[req_idx]
        n = len(indices_tensor)
        if n > 0:
            # 注意:如果第一个分隔符在位置 0(空查询),indices - 1 会回绕到 -1。
            # 这是无害的——第一个分隔符条目总被 _process_multi_item_scoring_results 丢弃。
            all_index_tensors.append(indices_tensor + (offset - 1))
        delim_counts.append(n)
        offset += req_seq_len
​
    if all_index_tensors:
        # 批量 CPU 张量堆叠后单次传输至 GPU,减少同步开销
        index_tensor = torch.cat(all_index_tensors).to(device, non_blocking=True)
    else:
        index_tensor = torch.tensor([], dtype=torch.long, device=device)
    return list(data[index_tensor].split(delim_counts))
python/sglang/srt/managers/tokenizer_manager_score_mixin.py dependency-wiring

依赖关系调整,修改 `_build_multi_item_token_sequence` 以返回分隔符索引,启动预计算索引传递链。

def _build_multi_item_token_sequence(
    self, query: List[int], items: List[List[int]], delimiter_token_id: int
) -> Tuple[List[int], List[int]]:
    """Build a single token sequence for multi-item scoring.    Returns:
        Tuple of (combined token sequence, delimiter indices)
    """
    combined_sequence = query[:] # 以查询 token 开始
    delimiter_indices = []
    for item in items:
        delimiter_indices.append(len(combined_sequence)) # 记录当前分隔符位置
        combined_sequence.append(delimiter_token_id) # 添加分隔符 token
        combined_sequence.extend(item) # 添加项目 tokens
    # 为 logprob 提取添加最后一个分隔符
    delimiter_indices.append(len(combined_sequence))
    combined_sequence.append(delimiter_token_id)
    return combined_sequence, delimiter_indices # 返回序列和索引

评论区精华

分隔符索引边界情况处理 正确性

gemini-code-assist[bot] 指出分隔符在索引 0 时可能导致负索引错误,开发者回应已在代码中添加注释说明该情况无害。

结论:接受边界情况,通过注释和后续处理逻辑确保正确性,未修改核心索引计算。 · 已解决

性能优化与索引传输 性能

kpham-sgl 建议批量索引传输至 GPU 以避免隐式阻塞,开发者更新代码将 CPU 张量堆叠后单次传输,但权衡后未使用 pin_memory。

结论:实现批量传输优化,平衡了性能和复杂性,当前方案可接受。 · 已解决

与 PR #22427 的兼容性 设计

kpham-sgl 询问 MIS 与非 MIS 请求混合批处理及返回类型差异的影响,开发者确认已 rebase 并修复,确保批次一致性。

结论:MIS 启用后批次内所有请求必须携带分隔符索引,接口差异已通过测试验证。 · 已解决

CUDA 图支持计划 设计

kpham-sgl 询问未来 CUDA 图支持,开发者表示将通过分段 CUDA 图提供端到端支持。

结论:列为未来改进方向,当前 PR 不包含 CUDA 图兼容性变更。 · unresolved

风险与影响

技术风险包括:

  • 回归风险pooler.pyscore_and_pool()的MIS路径返回张量列表而非单个张量,可能与下游消费者(如PR #22427引入的代码)期望不匹配,需确保接口一致性。
  • 边界条件错误:尽管注释说明空查询时分隔符索引为0的情况可接受,但若相关逻辑未妥善处理负索引,可能引发隐蔽的数据污染。
  • 性能退化:索引预计算和传输优化虽减少同步,但若CPU张量处理不当或批量大小不均,可能引入额外开销。
  • 兼容性风险:硬编码MIS_DELIMITER_TOKEN_ID = 9999要求模型词表大小超过此值,否则嵌入层会抛出索引越界错误,限制了模型适用性。

影响范围:

  • 用户影响:配置简化,只需--enable-mis单标志,自动处理不兼容设置,提升易用性;性能优化可提高评分任务吞吐约4.5%,减少延迟。
  • 系统影响:核心评分路径变更涉及tokenization、调度、前向传递和输出处理多个模块,减少GPU-CPU同步可能降低系统负载,但增加了索引传递的数据结构复杂度。
  • 团队影响:开发者需熟悉预计算索引的传递链条,测试覆盖全面(新增测试文件并更新现有测试),便于后续维护和扩展。
接口变更风险 边界条件处理 性能优化复杂度 模型兼容性限制

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论