Prhub

#39937 [Model Runner V2] Multiple prompt logprobs support

原始 PR 作者 yewentao256 合并时间 2026-04-21 23:49 文件变更 1 提交数 6 评论 9 代码增减 +48 / -18

执行摘要

V2 Model Runner 支持多个 prompt logprob

作为 #39337 的一部分,需要支持多个 prompt logprobs(top-k),并修复在启用 chunked prefill 和 preemption 时 prompt_logprobs 断言失败的问题。PR body 中的测试用例原本因位置只有 1 个 logprob 而报错 AssertionError: Output 0 position 1 has 1 logprobs, expected 2

值得精读以实现差异化批处理参数的传递模式。但需留意 gemini 指出的切片缺失问题,建议在合并后续 PR 或生产环境遇到断言错误时补充每个请求的精确截断逻辑。

讨论亮点
  • uses_prompt_logprobs 冗余性(design):njhill 询问引入 num_prompt_logprobs 后是否可移除 uses_prompt_logprobs;yewentao256 解释因 prompt_logprobs=0 是有效启用状态,需保留。
  • 中间变量简化(style):njhill 建议为 num_prompt_logprobs[needs_prompt_logprobs]is_prompt_chunked[i] 使用中间变量以避免重复索引,均已采纳。
  • 切片未达预期(correctness,来自 gemini-code-assist[bot]):当批处理中请求的 prompt_logprobs 值不同时,当前实现按最大值计算并返回所有位置的完整 top-k,而下游测试断言期望精确数量,可能导致 AssertionError。该评论未获直接回复,最终代码仍保持按批最大值返回,未按请求单独切片。

实现拆解

  1. 状态追踪:在 PromptLogprobsWorker.__init__ 中新增 self.num_prompt_logprobs 数组(np.int32),用于持久化每个请求的 sampling_params.prompt_logprobs
  2. 参数记录:在 add_request 中将 sampling_params.prompt_logprobs or 0 写入对应位置,同时保留 uses_prompt_logprobs 布尔数组以区分请求是否启用(prompt_logprobs=0 属于启用但无需返回)。
  3. 批处理最大值计算:在 compute_prompt_logprobs 中从 num_prompt_logprobs 提取当前批的各请求所需数量,计算 max_num_prompt_logprobs:若任一请求为 -1(请求所有 token),则取 -1;否则取整数最大值。
  4. 向下游传递:将 max_num_prompt_logprobs 作为额外参数传入 compute_prompt_logprobs_with_chunking,使其在计算时按全局最高数量生成 logprob 张量。该函数返回三元组 prompt_token_ids, prompt_logprobs, prompt_ranks(之前为二元组)。
  5. 结果收集:遍历批中请求,根据 is_prompt_chunkedstart_idx >= end_idx 条件选择构造 LogprobsTensorsNone;合并分块结果时对 None 作跳过处理。

涉及文件:仅 vllm/v1/worker/gpu/sample/prompt_logprob.py。无测试、配置或 schema 配套变更。

文件 模块 状态 重要度
vllm/v1/worker/gpu/sample/prompt_logprob.py 采样器 modified 6.83

关键符号

PromptLogprobsWorker.__init__ PromptLogprobsWorker.add_request PromptLogprobsWorker.compute_prompt_logprobs compute_prompt_logprobs_with_chunking

关键源码片段

vllm/v1/worker/gpu/sample/prompt_logprob.py core-logic

核心文件,实现多 prompt logprob 的追踪、计算与结果收集。改动集中在 PromptLogprobsWorker 类。

# vllm/v1/worker/gpu/sample/prompt_logprob.py
# 类初始化:新增 num_prompt_logprobs 数组
class PromptLogprobsWorker:
    def __init__(self, max_num_reqs: int):
        self.max_num_reqs = max_num_reqs
        self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
        self.num_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=np.int32) # 新增:记录每个请求的 prompt_logprobs 数量
        self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}
​
    def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
        uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
        self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
        # 保存具体数值,0 表示启用但不需要返回(与 uses_prompt_logprobs 协同)
        self.num_prompt_logprobs[req_idx] = sampling_params.prompt_logprobs or 0
        if uses_prompt_logprobs:
            self.in_progress_prompt_logprobs[req_id] = []
​
    def compute_prompt_logprobs(self, ...):
        # ... ( 省略数据准备 )
        # 获取当前批中各请求的 prompt_logprobs 数量
        num_prompt_logprobs = self.num_prompt_logprobs[idx_mapping_np]
        # ... 过滤出真正需要计算的请求
        requested_num_prompt_logprobs = num_prompt_logprobs[needs_prompt_logprobs]
        # 批最大值:若任一请求为 -1(表示全部),则整体取 -1
        max_num_prompt_logprobs = (
            -1 if np.any(requested_num_prompt_logprobs == -1)
            else int(requested_num_prompt_logprobs.max())
        )
        # 将最大值传入下游函数,确保一次计算产出足够多的 top-k
        prompt_token_ids, prompt_logprobs, prompt_ranks = (
            compute_prompt_logprobs_with_chunking(
                ..., max_num_prompt_logprobs, # 新增参数
            )
        )
        # ... 按请求切片,但未对 prompt_logprobs 的第二维进行按需裁剪
        # 注意:这里仅切了 token 位置维度,而 top-k 维度始终为 max_num_prompt_logprobs
        logprobs = LogprobsTensors(
            logprob_token_ids=prompt_token_ids[start_idx:end_idx],
            logprobs=prompt_logprobs[start_idx:end_idx], # shape: [positions, max_num_prompt_logprobs]
            selected_token_ranks=prompt_ranks[start_idx:end_idx],
        )
        # ...

评论区精华

uses_prompt_logprobs 是否冗余 设计

njhill 提问引入 num_prompt_logprobs 后可否移除 uses_prompt_logprobs;yewentao256 回应因 prompt_logprobs=0 是启用态,需保留 uses_prompt_logprobs 做布尔过滤。

结论:保留 uses_prompt_logprobs,num_prompt_logprobs 补充数值信息。 · 已解决

批处理最大值未按请求切片 正确性

gemini-code-assist[bot] 指出当批中请求的 prompt_logprobs 不同时,所有请求收到最大值个数的 logprobs,可能导致下游断言失败。建议按请求实际数量切片。

结论:未直接回复,最终代码未实现切片,风险遗留。 · unresolved

代码简化建议 style

njhill 建议使用中间变量避免重复索引(num_prompt_logprobs[needs_prompt_logprobs] 和 is_prompt_chunked[i]),并简化 logprobs 赋值为三元表达式。

结论:已采纳,最终代码应用了这些简化。 · 已解决

风险与影响

回归/正确性风险:混合请求批处理时未根据请求实际 prompt_logprobs 对返回张量进行切片,导致所有请求获得相同最大数量的 top-k logprobs。若下游组件(如测试断言)检查精确数量,可能触发断言失败。PR body 中的测试仅使用了统一值 2,未暴露该问题。
性能风险:引入 max_num_prompt_logprobs 计算和 compute_prompt_logprobs_with_chunking 的额外维度,但影响局限于启用 prompt logprobs 的请求,通常为少数。
兼容性:仅影响 V2 Model Runner(VLLM_USE_V2_MODEL_RUNNER=1),V1 runner 不受影响。函数签名变化(返回三元组)已同步至调用方。

用户:V2 Model Runner 用户可指定 sampling_params.prompt_logprobs 为大于 0 的整数来获取多个 prompt logprobs,但混合批处理时可能收到超出预期的 logprobs 数量,需要下游容忍。
系统:新增数组和分支,但仅在请求启用时产生额外计算,影响有限。
团队:代码简化后维护性提升;未解决的切片问题需后续跟踪(如 #39337 后续 PR)。

混合请求未按需切片 缺少多请求差异化测试覆盖 需下游容忍额外 logprobs

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论