Prhub

#42329 [Bugfix][Frontend] Default max_tokens server-side on /inference/v1/generate

原始 PR 作者 hallerite 合并时间 2026-05-13 17:16 文件变更 5 提交数 4 评论 2 代码增减 +168 / -2

执行摘要

为 /inference/v1/generate 添加服务端 max_tokens 默认值,防止静默截断。

SamplingParams.max_tokens 在数据类级别默认值为 16,/inference/v1/generate 端点未做任何处理,导致客户端省略 max_tokens 时生成在 16 个 token 后静默截断。OpenAI 兼容端点已用 get_max_tokens 处理此问题,但 disagg 路径完全跳过默认设置,使得多轮对话、长文本生成等场景出现不完整输出且无任何警告。本 PR 旨在消除这一静默错误。

此 PR 值得精读,尤其是 pydantic 模型验证器追踪客户端字段的技巧,该模式可用于其他需要区分“未设置”与“显式默认值”的场景(如 temperaturetop_p 等)。其实现与测试设计清晰,有助于理解 vLLM 请求处理管线的不同层。

讨论亮点

Review 中 gemini-code-assist[bot] 指出了一个潜在问题:get_max_tokens 在提示长度超过模型最大上下文时会引发 ValueError,可能导致 500 内部错误。作者 hallerite 回应称全局异常处理器(api_server.py:269)已将 ValueError 映射为 400 响应,因此无需额外捕获。这个讨论没有进一步争议,PR 获得批准。此外,合并者 NickLucche 在批准时指出,未来应当考虑将前端默认逻辑提取为共享工具,但本次修复范围之外。

实现拆解

  1. 协议层改造vllm/entrypoints/serve/disagg/protocol.py):为 GenerateRequest 添加一个 pydantic model_validator_capture_sampling_params_provided_keys)和一个 PrivateAttr 属性 _sampling_params_provided_keys。当从 JSON 字典解析请求时,验证器提取 sampling_params 字段的所有键并存入集合;当从预构建的 SamplingParams 实例构造时,该属性保持 None,表示所有字段都已显式提供。新增 is_sampling_param_provided 方法供服务层查询。
  2. 服务层初始化vllm/entrypoints/serve/disagg/serving.py):在 ServingTokens.__init__ 中计算 default_sampling_paramsoverride_max_tokens,与 OpenAIServingChat 保持一致,用于服务器端默认值计算。
  3. 服务层请求处理vllm/entrypoints/serve/disagg/serving.py serve_tokens 方法):在将 sampling_params 传递给引擎之前,检查 request.is_sampling_param_provided("max_tokens")。若未提供,则调用 get_max_tokens(导入自 vllm.entrypoints.utils)计算合适的最大值,并将结果赋值给 sampling_params.max_tokens
  4. 测试配套
    • 新增 tests/entrypoints/serve/disagg/test_protocol.py,包含 5 个单元测试,验证省略 max_tokens、显式设置相同或不同值、其他字段独立跟踪、JSON 往返一致性以及内部构造路径都按预期工作。
    • tests/entrypoints/serve/disagg/test_serving_tokens.py 中添加 test_generate_defaults_max_tokens_when_omitted 回归测试,向服务器发送不含 max_tokens 的请求,断言生成的 token 数超过 16(等于 max_model_len - prompt_len)。
    • 修改 tests/entrypoints/openai/test_openai_schema.py,将 POST /inference/v1/generate 加入 LONG_TIMEOUT_SECONDS 组,因为修复后请求可能超过默认 10 秒超时。
文件 模块 状态 重要度
vllm/entrypoints/serve/disagg/protocol.py 协议层 modified 7.58
vllm/entrypoints/serve/disagg/serving.py 服务层 modified 6.7
tests/entrypoints/serve/disagg/test_protocol.py 协议测试 added 7.04
tests/entrypoints/serve/disagg/test_serving_tokens.py 服务测试 modified 5.43
tests/entrypoints/openai/test_openai_schema.py 架构测试 modified 3.32

关键符号

_capture_sampling_params_provided_keys is_sampling_param_provided ServingTokens.__init__ serve_tokens test_omitted_max_tokens_is_not_provided test_explicit_max_tokens_is_provided test_generate_defaults_max_tokens_when_omitted

关键源码片段

vllm/entrypoints/serve/disagg/protocol.py core-logic

核心变更:添加 `_sampling_params_provided_keys` 私有属性和 `model_validator`,区分客户端显式设置的字段与数据类默认值。

# SPDX-License-Identifier: Apache-2.0
from typing import Any
from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator
from vllm.sampling_params import SamplingParamsclass GenerateRequest(BaseModel):
    # ... 其他字段 ...
​
    # 记录客户端在 JSON body 中显式设置的 sampling_params 键。
    # None 表示请求是通过预构建的 SamplingParams 实例构造的,
    # 此时所有字段都视为已显式提供。
    _sampling_params_provided_keys: set[str] | None = PrivateAttr(default=None)
​
    @model_validator(mode="wrap")
    @classmethod
    def _capture_sampling_params_provided_keys(cls, data: Any, handler):
        provided: set[str] | None = None
        if isinstance(data, dict):
            sp = data.get("sampling_params")
            if isinstance(sp, dict):
                provided = set(sp.keys())
        instance = handler(data)
        instance._sampling_params_provided_keys = provided
        return instance
​
    def is_sampling_param_provided(self, name: str) -> bool:
        """检查调用方是否 显式 设置了 `sampling_params.<name>`。        对于从 JSON body 解析的请求,反映原始输入字典。
        对于通过预构建的 SamplingParams 实例构造的请求,
        所有字段都视为已提供,以免服务器端默认值覆盖上游已解析的值。
        """
        if self._sampling_params_provided_keys is None:
            return True
        return name in self._sampling_params_provided_keys
vllm/entrypoints/serve/disagg/serving.py dependency-wiring

服务层集成:在 __init__ 中计算默认参数,在 serve_tokens 中应用服务器端默认值。

# within class ServingTokens(OpenAIServing):
    def __init__(self, engine_client, models, openai_serving_render, *, ...):
        # ... 原有初始化代码 ...
        # 模拟 OpenAIServingChat,以便在客户端省略 max_tokens 时应用服务器端默认值。
        # 没有这个,SamplingParams.max_tokens 会回退到数据类默认值 16,
        # 并静默截断所有生成。
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        mc = self.model_config
        self.override_max_tokens = (
            self.default_sampling_params.get("max_tokens")
            if mc.generation_config not in ("auto", "vllm")
            else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
        )
​
    async def serve_tokens(self, request: GenerateRequest, raw_request=None):
        # ... 前面的代码 ...
        if not request.is_sampling_param_provided("max_tokens"):
            sampling_params.max_tokens = get_max_tokens(
                max_model_len=self.model_config.max_model_len,
                max_tokens=None,
                input_length=self._extract_prompt_len(engine_input),
                default_sampling_params=self.default_sampling_params,
                override_max_tokens=self.override_max_tokens,
            )
        # ... 继续 ...

评论区精华

ValueError 未处理风险 正确性

gemini-code-assist[bot] 指出 `get_max_tokens` 在提示超过模型上下文长度时会引发 `ValueError`,可能导致 500 错误而未给出适当的 400 响应。因此建议添加 try-except 块。

结论:hallerite 回应全局异常处理器已处理 ValueError 转 400;PR 保持现状并得到批准。 · 已解决

前端逻辑未来重构 设计

合并者 NickLucche 在批准时评论提到,应将前端默认逻辑提取出公共方法以便多端点共享,但此次修复范围不足。

结论:被视为超出范围,未实施,PR 正常合并。 · 已解决

风险与影响

主要风险在于错误处理:如果 get_max_tokens 因提示长度超过模型最大长度而抛出 ValueError,依赖全局异常处理器将其转换为 400 响应。虽然全局处理器确实有此功能,但在某些中间件或自定义错误处理场景下可能不适用,因此建议未来显式捕获以增强鲁棒性。次要风险是内部构造路径的假设(_sampling_params_provided_keys=None 表示全部已提供)可能掩盖某些调用方未正确解析字段的情况;但目前所有内部调用者都使用预解析的 SamplingParams 实例,因此风险较低。此外,新逻辑仅当 max_tokens 未提供时介入,对现有显式提供 max_tokens 的请求无影响。

直接影响:所有调用 /inference/v1/generate 且未显式设置 max_tokens 的客户端将不再获得 16-token 截断,而是生成接近 max_model_len - prompt_len 的长文本。这对于工具调用、长回复、代码生成等场景至关重要。间接影响:服务器端额外计算 get_max_tokens 的开销极小;schemathesis 测试超时增加至 60 秒以应对长生成。团队需要关注不同模型配置下 get_max_tokens 的默认值是否合理(通过 default_sampling_paramsoverride_max_tokens 可调整)。

全局异常处理器依赖错误处理 核心请求路径变更 未显式捕获 ValueError

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论