执行摘要
- 一句话:去重 PD 分离的 logprob 归一化逻辑
- 推荐动作:值得快速合入,这是典型的代码去重重构,提升一致性和可维护性。建议阅读
batch_result_processor.py 中的 move_logprobs_to_cpu 方法,理解共享的 logprob 归一化契约。
功能与动机
PR body 明确说明:PD prefill 路径过去手动内联了与 batch_result_processor._move_logprobs_to_cpu 几乎相同的 tolist() 转换逻辑,导致两份重复代码。本 PR 目的是让 PD prefill 复用共享的归一化辅助函数,确保 prefill 路径对 top-logprob 行使用相同的生产者侧契约,并在现有 disagg 测试中加强断言。
实现拆解
- 暴露私有方法为公有:在
batch_result_processor.py 中将 _move_logprobs_to_cpu 重命名为 move_logprobs_to_cpu(去掉下划线前缀),使其可从外部访问。
- PD prefill 调用共享方法:在
prefill.py 的 process_batch_result_disagg_prefill 中,删除约 20 行内联的 logprob 归一化代码,替换为一行调用 self.batch_result_processor.move_logprobs_to_cpu(...)。
- 强化测试断言:在
test_disaggregation_basic.py 的 test_chat_completion_top_logprobs 中,将 assertGreater(len(top_logprobs), 0) 收紧为 assertEqual(len(top_logprobs), 5)(严格匹配请求的 top_logprobs=5),并新增 assertIsInstance 检查 logprob 类型为 float。
关键文件:
python/sglang/srt/disaggregation/prefill.py(模块 解耦调度;类别 source;类型 core-logic): PD prefill 入口,删除了内联 logprob 归一化代码,改为调用共享方法。
python/sglang/srt/managers/scheduler_components/batch_result_processor.py(模块 批处理;类别 source;类型 core-logic;符号 _move_logprobs_to_cpu, move_logprobs_to_cpu): 共享的 logprob 归一化方法从私有变为公有,普通 prefill 调用也同步更新。
test/registered/disaggregation/test_disaggregation_basic.py(模块 解耦调度;类别 test;类型 test-coverage): 测试断言收紧,验证 top_k 数量和 logprob 类型。
关键符号:move_logprobs_to_cpu, process_batch_result_disagg_prefill
关键源码片段
python/sglang/srt/disaggregation/prefill.py
PD prefill 入口,删除了内联 logprob 归一化代码,改为调用共享方法。
# 替换后的函数片段:
def process_batch_result_disagg_prefill(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
) -> None:
# ... 前置代码保持不变 ...
logprob_pt = 0
next_token_ids = result.next_token_ids.tolist()
# 调用共享的 logprob 归一化方法,替代之前的内联重复代码
self.batch_result_processor.move_logprobs_to_cpu(
batch=batch,
logits_output=logits_output,
)
# ... 后续代码不变 ...
python/sglang/srt/managers/scheduler_components/batch_result_processor.py
共享的 logprob 归一化方法从私有变为公有,普通 prefill 调用也同步更新。
# 方法重命名:移去下划线前缀,暴露为公有方法
def move_logprobs_to_cpu(
self,
*,
batch: ScheduleBatch,
logits_output: LogitsProcessorOutput,
) -> None:
# 仅当 batch 要求 logprob 时才进行转换
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
# top_logprobs 相关字段也做同样转换
if logits_output.next_token_top_logprobs_val:
logits_output.next_token_top_logprobs_val = [
v.tolist() for v in logits_output.next_token_top_logprobs_val
]
logits_output.next_token_top_logprobs_idx = [
x.tolist() for x in logits_output.next_token_top_logprobs_idx
]
if logits_output.next_token_token_ids_logprobs_val:
logits_output.next_token_token_ids_logprobs_val = [
v.tolist() for v in logits_output.next_token_token_ids_logprobs_val
]
test/registered/disaggregation/test_disaggregation_basic.py
测试断言收紧,验证 top_k 数量和 logprob 类型。
# 测试函数中的收紧断言
first_top_logprobs = next(
(item.top_logprobs for item in content_logprobs if item.top_logprobs),
None,
)
self.assertIsNotNone(first_top_logprobs)
# 严格验证请求的 top_k 数量
self.assertEqual(len(first_top_logprobs), 5)
self.assertIsInstance(first_top_logprobs[0].token, str)
# 新增类型断言,确保 logprob 是 float 而非整数或其他类型
self.assertIsInstance(first_top_logprobs[0].logprob, float)
评论区精华
风险与影响
- 风险:低风险。仅涉及将已存在的逻辑抽取为公有方法并替换调用,行为语义一致;测试断言的收紧有助于减少回归可能。但需确认
move_logprobs_to_cpu 在 PD prefill 执行环境中(可能缺少某些 batch 上下文)行为与之前内联版本完全一致——本 PR 的调用传递了 batch 和 logits_output,与内联代码的条件 batch.return_logprob 一致,因此风险很低。
- 影响:影响范围有限:仅影响 PD 分离场景的 prefill 处理逻辑(
prefill.py),代码量减少约 20 行,降低维护成本。测试断言的收紧可防止 top_k 数量或类型在后续改动中意外偏离。对其他模块无影响。
- 风险标记:低风险重构, 测试增强
关联脉络
- PR #27004 fix(disagg): correct DSA/SWA state-page transfer mismatch in PD disaggregation: 同一作者近期修复了 PD 分离中的数据页面传输错误,本 PR 是对同一模块(
prefill.py)的代码质量清理。
参与讨论