执行摘要
- 一句话:为评分API添加返回pooled隐藏状态功能,支持序列分类和奖励模型。
- 推荐动作:此PR值得精读,特别关注池化层和调度器中的设计决策:如何平衡功能需求与性能(如CUDA图处理)、如何处理混合batch中的标志聚合,以及模型文件中的条件返回模式。这些决策对类似API扩展有参考价值。
功能与动机
根据PR body,此功能旨在为下游消费者提供模型内部表示,用于蒸馏、可解释性或二次评分管道。引用PR body原话:'This is useful for downstream consumers that need the model's internal representation alongside the final scores — e.g., for distillation, interpretability, or secondary scoring pipelines.'
实现拆解
- 核心数据结构扩展:在
EmbeddingPoolerOutput(pooler.py)和ScoreResult(tokenizer_manager_score_mixin.py)中添加pooled_hidden_states字段,用于存储任务头前的隐藏状态。
- 池化逻辑更新:在
pooler.py中新增pool_hidden_states()函数实现LAST/CLS池化,并修改score_and_pool()在单项目和多项目评分(MIS)路径下条件性捕获隐藏状态。
- 请求流水线贯通:在
ForwardBatch、Req、ScheduleBatch等数据结构中添加return_pooled_hidden_states标志,并通过调度器(scheduler.py)和输出处理器(scheduler_output_processor_mixin.py)传递,确保从请求到响应的完整链路。
- 模型适配:更新多个序列分类和奖励模型文件(如llama_reward.py、qwen2_rm.py),在forward方法中根据标志返回pooled hidden states。
- 测试与配置配套:新增测试文件
test_pooled_hidden_states.py覆盖Engine API、HTTP集成和错误场景;在server_args.py中添加MIS模式下的CUDA图禁用逻辑,确保兼容性。
关键文件:
python/sglang/srt/layers/pooler.py(模块 池化层;类别 source;类型 core-logic;符号 pool_hidden_states, score_and_pool, EmbeddingPoolerOutput): 核心池化层,新增pool_hidden_states函数并扩展score_and_pool以支持pooled hidden states捕获,是功能实现的基础。
python/sglang/srt/managers/tokenizer_manager_score_mixin.py(模块 评分管理器;类别 source;类型 dependency-wiring;符号 ScoreResult, _process_multi_item_scoring_results, _process_single_item_scoring_results): 评分结果处理的核心文件,扩展ScoreResult以包含pooled_hidden_states,并更新处理逻辑以支持新标志。
python/sglang/srt/models/llama_reward.py(模块 模型适配;类别 source;类型 data-contract;符号 LlamaForSequenceClassification.forward, LlamaForSequenceClassificationWithNormal_Weights.forward): 关键模型文件,展示如何在新功能中适配序列分类模型的forward方法,条件性返回pooled hidden states。
test/registered/prefill_only/test_pooled_hidden_states.py(模块 测试套件;类别 test;类型 test-coverage;符号 TestPooledHiddenStatesEngine, TestPooledHiddenStatesMISEngine, TestPooledHiddenStatesHTTP, TestPooledHiddenStatesCausalLMRejection): 新增的全面测试文件,覆盖Engine API、HTTP集成、MIS模式和错误处理,确保功能正确性和稳定性。
关键符号:pool_hidden_states, score_and_pool, LlamaForSequenceClassification.forward, TokenizerManagerScoreMixin.score_request, Scheduler.run_batch
关键源码片段
python/sglang/srt/layers/pooler.py
核心池化层,新增pool_hidden_states函数并扩展score_and_pool以支持pooled hidden states捕获,是功能实现的基础。
def pool_hidden_states(
pooling_type: PoolingType,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
"""Pool hidden_states by PoolingType (LAST/CLS).
Raw pooling only — no normalize, no dim truncation.
Returns shape (batch_size, hidden_size).
"""
if pooling_type == PoolingType.LAST:
# 提取每个序列的最后一个token的隐藏状态
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
return hidden_states[last_token_indices]
elif pooling_type == PoolingType.CLS:
# 提取每个序列的第一个token(CLS)的隐藏状态
prompt_lens = forward_batch.extend_seq_lens
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
return hidden_states[first_token_flat_indices]
else:
raise ValueError(f"Unsupported pooling type: {pooling_type}")
python/sglang/srt/managers/tokenizer_manager_score_mixin.py
评分结果处理的核心文件,扩展ScoreResult以包含pooled_hidden_states,并更新处理逻辑以支持新标志。
@dataclass(frozen=True, slots=True)
class ScoreResult:
scores: List[List[float]]
prompt_tokens: int = 0
# Per-item pooled hidden states (pre-head transformer output).
# CPU tensors when return_pooled_hidden_states=True; kept as tensors so
# in-process consumers (gRPC, engine API) avoid a .tolist() round-trip.
# The HTTP path converts to lists in serving_score.py before JSON serialization.
# Same layout as scores: one tensor per item (not a single packed 2D tensor).
pooled_hidden_states: Optional[List[Optional[torch.Tensor]]] = None
评论区精华
风险与影响
- 风险:1. 性能开销:当batch中包含混合请求(部分需要pooled hidden states,部分不需要)时,所有请求都会触发隐藏状态的计算和CPU复制,可能增加延迟和内存使用。
2. CUDA图兼容性:MIS模式下强制禁用CUDA图和piecewise CUDA图(通过server_args.py的_handle_multi_item_scoring),可能影响高吞吐场景的性能。
3. 回归风险:变更涉及调度器、池化层和多个模型文件,若新标志处理不当,可能破坏现有评分功能的正确性,尤其是MIS路径下的边界情况。
4. 数据一致性:在scheduler_output_processor_mixin.py中,初始实现存在IndexError风险(混合batch中pooled hidden states长度不匹配),虽已修复,但类似逻辑需持续关注。
- 影响:对用户:新增API参数
return_pooled_hidden_states,为序列分类和奖励模型用户提供模型内部表示,扩展了蒸馏、可解释性分析等下游应用场景,默认关闭不影响现有行为。对系统:在评分流水线中引入额外数据流,增加了复杂性和轻微性能开销,但通过优化(如条件性捕获)和全面测试最小化影响。对团队:需维护新字段和相关逻辑,测试覆盖确保了功能稳定性,但后续需关注CUDA图在MIS模式下的优化。
- 风险标记:性能开销, CUDA图兼容性, 回归风险
关联脉络
- PR #21887 [Ray] Add data parallel (DP) and DP attention support to RayEngine: 同样扩展评分API功能,涉及RayEngine的评分支持,可参考其跨模块变更模式。
- PR #22897 streaming session: trim spec v2 overshoot in cache_finished_req: 涉及评分相关的缓存和会话管理,但焦点不同;此PR补充了评分API的新特性。
参与讨论