Prhub

#42368 [chore] Refactor pooling metadata token ID accessors

原始 PR 作者 taneem-ibrahim 合并时间 2026-05-13 14:08 文件变更 4 提交数 5 评论 5 代码增减 +142 / -142

执行摘要

重构池化响应构建,消除编码逻辑重复

PR body 指出:“移除 embedding 和 pooling serving 路径间重复的 JSON/Bytes 编码逻辑,将共享辅助函数集中到 utils.py,提高可维护性和一致性,同时保持行为和响应格式不变。”

值得精读,特别是 utils.py 中函数提取的设计模式和 _get_prompt_token_ids 的 DRY 实践。该 PR 是典型的“提取+集中”重构,展示了如何消除跨模块重复逻辑。建议在类似场景中参考其抽象粒度。同时,可关注 review 中关于 None 安全的设计讨论。

讨论亮点
  • gemini-code-assist 指出 None 安全问题:在 get_pooling_usage 中直接 len(output.prompt_token_ids) 可能因 prompt_token_idsNone 而抛出 TypeError,建议增加 if not None else 0 保护。该建议被采纳,最终代码已修改。
  • yewentao256 对关键字参数风格提出异议:建议不要强制使用关键字参数(* 分隔),认为“除非确实需要,否则不应使用”。该评论未导致代码修改,最终函数签名仍使用普通位置参数。
  • yewentao256 和 noooop 先后批准:经过一次 pre-commit 失败修复后,最终获得维护者批准合并。

实现拆解

  1. 集中编码与用量计算(utils.py:新增 get_pooling_output_encoder 根据 encoding_format 返回对应的编码函数(float/base64),替代两端各自 cast + partial 的重复代码。新增 get_pooling_usage / get_pooling_usage_payload 统一从 PoolingRequestOutput.prompt_token_ids 计算 token 用量,并处理 None 情形。修改 encode_pooling_bytes 不再返回 usage,改为纯 bytes 与 metadata 二元组。新增 build_pooling_bytes_streaming_response 封装 streaming response 构造(含 bytes 拼接、metadata header 设置)。
  2. 简化 embedding 服务(embed/serving.py:删除内联的 ndarray_num_tokens 循环计算、usage 构建、encode_fn 手写判断,改为调用 get_pooling_usageget_pooling_output_encoder。bytes 响应路径直接委托 build_pooling_bytes_streaming_response
  3. 简化 pooling 服务(pooling/serving.py:与 embedding 服务类似,删除内联 usage 计算与 encode_fn 选择,统一使用工具函数。bytes 响应同样委托。
  4. 提取 token ID 访问器(v1/pool/metadata.py:将 get_prompt_token_ids 中的断言 + 切片逻辑提取为私有方法 _get_prompt_token_ids,令 get_prompt_token_ids_cpu 也复用该方法,去除代码复制。
  5. 配套调整:更新了所有 import,删除了不再需要的 jsonpartialcastUsageInfo 等导入。未新增测试文件,但所有现有测试应继续通过。
文件 模块 状态 重要度
vllm/entrypoints/pooling/utils.py 工具层 modified 8.06
vllm/entrypoints/pooling/embed/serving.py 嵌入服务 modified 6.42
vllm/entrypoints/pooling/pooling/serving.py 池化服务 modified 6.36
vllm/v1/pool/metadata.py 元数据 modified 6.04

关键符号

encode_pooling_output_float get_pooling_output_encoder get_pooling_usage get_pooling_usage_payload build_pooling_bytes_streaming_response _get_prompt_token_ids get_prompt_token_ids get_prompt_token_ids_cpu

关键源码片段

vllm/entrypoints/pooling/utils.py core-logic

核心变更文件,集中了编码、用量计算和流式响应构建等所有新工具函数,是整个重构的基础。

# vllm/entrypoints/pooling/utils.py
# 定义编码格式字面量类型
JsonEncodingFormat = Literal["float", "base64"]
BytesEncodingFormat = Literal["bytes", "bytes_only"]
FloatEncodedPoolingOutput = list[float] | list[list[float]]
JsonEncodedPoolingOutput = FloatEncodedPoolingOutput | str
​
​
def get_pooling_output_encoder(
    encoding_format: JsonEncodingFormat,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> Callable[[PoolingRequestOutput], JsonEncodedPoolingOutput]:
    """根据 encoding_format 返回对应的编码函数,消除两端重复的 cast+partial 逻辑。"""
    return cast(
        Callable[[PoolingRequestOutput], JsonEncodedPoolingOutput],
        (
            encode_pooling_output_float
            if encoding_format == "float"
            else partial(
                encode_pooling_output_base64,
                embed_dtype=embed_dtype,
                endianness=endianness,
            )
        ),
    )
​
​
def get_pooling_usage(
    pooling_outputs: Sequence[PoolingRequestOutput],
) -> UsageInfo:
    """统一计算 prompt_tokens 和 total_tokens,处理 prompt_token_ids 可能为 None 的情况。"""
    num_prompt_tokens = sum(
        len(output.prompt_token_ids) if output.prompt_token_ids is not None else 0
        for output in pooling_outputs
    )
    return UsageInfo(prompt_tokens=num_prompt_tokens, total_tokens=num_prompt_tokens)
​
​
def build_pooling_bytes_streaming_response(
    pooling_outputs: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: BytesEncodingFormat,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> StreamingResponse:
    """构建 bytes 格式的流式响应,封装了 metadata 头部处理。"""
    content, items = encode_pooling_bytes(
        pooling_outputs, embed_dtype, endianness
    )
    usage = get_pooling_usage_payload(pooling_outputs)
    # bytes_only 模式下不加 metadata 头
    headers = None if encoding_format == "bytes_only" else {
        "metadata": json.dumps({
            "items": items,
            "usage": usage,
        })
    }
    return StreamingResponse(
        content=content,
        media_type="application/octet-stream",
        headers=headers,
    )
vllm/v1/pool/metadata.py core-logic

提取 _get_prompt_token_ids 私有方法,消除 get_prompt_token_ids 与 get_prompt_token_ids_cpu 之间的断言 + 切片重复。

# vllm/v1/pool/metadata.py
# 提取的私有方法,统一处理 prompt_token_ids 非空断言与切片
    def _get_prompt_token_ids(
        self,
        prompt_token_ids: torch.Tensor | None,
    ) -> list[torch.Tensor]:
        assert prompt_token_ids is not None, (
            "Please set `requires_token_ids=True` in `get_pooling_updates`"
        )
        # 按每个序列的实际 prompt 长度截取 token ids
        return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
​
    def get_prompt_token_ids(self) -> list[torch.Tensor]:
        # 委托给内部方法,使用 GPU 侧的 prompt_token_ids
        return self._get_prompt_token_ids(self.prompt_token_ids)
​
    def get_prompt_token_ids_cpu(self) -> list[torch.Tensor]:
        # 同样委托,使用 CPU 侧的 prompt_token_ids
        return self._get_prompt_token_ids(self.prompt_token_ids_cpu)

评论区精华

get_pooling_usage 中 prompt_token_ids 可能为 None 的安全风险 正确性

gemini-code-assist 指出 get_pooling_usage 直接调用 len(output.prompt_token_ids) 未检测 None,在 V1 引擎中可能导致 TypeError 和 500 错误,建议增加 None 检查。

结论:已采纳,最终代码使用 `len(output.prompt_token_ids) if output.prompt_token_ids is not None else 0` 处理。 · 已解决

是否应使用强制关键字参数 style

yewentao256 查看 get_pooling_output_encoder 等函数后建议不要使用关键字参数,原话“除非确实需要,否则不应使用”。

结论:未采纳,最终函数仍使用普通位置参数,未强制关键字。 · no-action

Pre-commit 失败修复 other

mergify[bot] 两次提示 pre-commit 检查失败,作者后续提交修复。

结论:通过修复 pre-commit 问题后,yewentao256 和 noooop 依次批准。 · 已解决

风险与影响

  • 兼容性风险:重构保持响应格式不变,但若下游依赖内部 encode_pooling_bytes 的旧返回格式(原来返回 (body, items, usage) 三元组,现为 (body, items) 二元组),可能出现解包错误。经检查,encode_pooling_bytes 仅被 embed/serving.pypooling/serving.py 使用,且两者均已适配新签名,无外部调用风险。
  • None 安全风险get_pooling_usage 中若 prompt_token_idsNone 会导致崩溃。该问题已在 review 中被指出并修复,现用 len(output.prompt_token_ids) if output.prompt_token_ids is not None else 0 处理。
  • 回归风险:PR 未新增测试,依赖现有测试覆盖。如果现有测试对 tokens 用量的精确性要求不足,可能遗漏边缘情况。建议后续补充单元测试。
  • 性能风险:无,重构仅涉及响应组装逻辑,无额外开销。
  • 用户影响:无,响应格式与行为完全一致。
  • 系统影响:降低了 embedding 和 pooling 服务的代码复杂度,使后续添加新编码格式或变更响应结构更加容易。PoolingMetadata 的 token ID 访问器更易于维护。
  • 团队影响:减少了两个服务路径之间的知识碎片,新团队更容易理解响应构造流程。但缺少测试覆盖是一点隐患,建议后续补加。
None 安全风险已修复 无新增测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论