Prhub

#42946 [Frontend] Consolidate beam search by BeamSearchMixin.

原始 PR 作者 noooop 合并时间 2026-05-19 15:37 文件变更 9 提交数 3 评论 6 代码增减 +459 / -375

执行摘要

将 beam search 逻辑提取到独立的 Mixin 类中,统一离线与在线入口。

根据 #42267,将 beam search 的实现从 LLM 和 OpenAIServing 类中提取为独立的 mixin,以提高代码模块化和复用性,便于后续维护和扩展。

该 PR 是清晰的重构,值得阅读以理解 vLLM 入口点的 mixin 设计模式。但需注意 review 中未解决的索引 bug 和性能问题,建议在合并前或后续提交中修复。

讨论亮点
  • 索引错误:gemini-code-assist[bot] 指出 online 实现中当 logprobsNone 时索引偏移,可能导致 beam 搜索结果错误(严重)。
  • 性能问题:离线实现中使用 sum([], ...) 展平列表,建议改为 itertools.chain.from_iterable 以避免 O(N^2) 复杂度。
  • 设计耦合:DarkLight1337 对 mixin 依赖外部方法表示疑虑,作者解释“先放置代码,后续再优化 API”,最终批准。

实现拆解

  1. 创建目录结构 vllm/entrypoints/generate/beam_search/,添加 __init__.pyutils.py(从 vllm/beam_search.py 重命名)。
  2. 提取离线 beam search 逻辑到 BeamSearchOfflineMixin,位于 offline.py,该方法原为 LLM 类方法,现在通过多继承注入。
  3. 提取在线 beam search 逻辑到 BeamSearchOnlineMixin,位于 online.py,该方法原为 OpenAIServing 类方法,现在通过继承使用。
  4. 修改 vllm/entrypoints/llm.py:删除内联 beam_search 方法及其导入,改为导入并继承 BeamSearchOfflineMixin
  5. 修改 vllm/entrypoints/openai/engine/serving.py:删除内联 beam_search 方法及相关导入,改为导入并继承 BeamSearchOnlineMixin,并调整 super().__init__() 调用。
  6. 更新 CI 配置:.buildkite/test-amd.yaml.buildkite/test_areas/samplers.yaml 更新测试路径以匹配新文件结构。
文件 模块 状态 重要度
vllm/entrypoints/generate/beam_search/offline.py 入口层 added 9.21
vllm/entrypoints/generate/beam_search/online.py 入口层 added 8.6
vllm/entrypoints/openai/engine/serving.py 入口层 modified 8.53
vllm/entrypoints/llm.py 入口层 modified 8.17
vllm/entrypoints/generate/beam_search/utils.py 入口层 renamed 4.58
vllm/entrypoints/generate/__init__.py 入口层 added 3.3
vllm/entrypoints/generate/beam_search/__init__.py 入口层 added 3.3
.buildkite/test-amd.yaml CI 配置 modified 2.84
.buildkite/test_areas/samplers.yaml CI 配置 modified 2.6

关键符号

BeamSearchOfflineMixin.beam_search BeamSearchOnlineMixin.beam_search LLM.__init__ OpenAIServing.__init__

关键源码片段

vllm/entrypoints/generate/beam_search/offline.py dependency-wiring

新增文件,包含离线 beam search 核心逻辑 BeamSearchOfflineMixin,是重构的核心。

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM projectimport itertools
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from typing import Anyfrom tqdm import tqdmfrom vllm import PromptType, RequestOutput, TextPrompt, TokensPrompt
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.sampling_params import BeamSearchParams, SamplingParamsfrom .utils import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)logger = init_logger(__name__)
​
​
class BeamSearchOfflineMixin(ABC):
    """Offline inference for beam search"""
​
    renderer: BaseRenderer
​
    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        # 从 params 提取配置
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        length_penalty = params.length_penalty
​
        # 获取 tokenizer 并创建排序函数
        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
​
        # 预处理 prompt 并生成 lora 请求序列
        engine_inputs = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))
​
        # 配置并发限制
        if concurrency_limit is None:
            concurrency_limit = len(engine_inputs)
​
        # 每个步骤生成 2 * beam_width 候选,参考 HuggingFace Transformers
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True, # 内部 beam search,安全跳过克隆
        )
​
        # 初始化 beam 实例
        instances: list[BeamSearchInstance] = []
        for lora_req, prompt in zip(lora_requests, engine_inputs):
            if prompt["type"] == "embeds":
                raise NotImplementedError("Embedding prompt not supported for beam search")
            instances.append(
                BeamSearchInstance(prompt, lora_request=lora_req, logprobs=None),
            )
​
        # 按批次处理 prompt
        for prompt_start in range(0, len(instances), concurrency_limit):
            # ... 核心循环省略,包括展平所有 beam、生成、评分和选择 top beam
            pass
vllm/entrypoints/generate/beam_search/online.py dependency-wiring

新增文件,包含在线 beam search 核心逻辑 BeamSearchOnlineMixin,是重构的核心。

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM projectimport asyncio
from abc import ABC
from collections.abc import AsyncGenerator, Mappingimport numpy as npfrom vllm import CompletionOutput, RequestOutput
from vllm.engine.protocol import EngineClient
from vllm.inputs import EngineInput
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.utils import random_uuid
from vllm.utils.async_utils import collect_from_async_generatorfrom .utils import BeamSearchSequence, create_sort_beams_key_function
​
​
class BeamSearchOnlineMixin(ABC):
    """online serving for beam search"""
​
    renderer: BaseRenderer
    engine_client: EngineClient
​
    async def beam_search(
        self,
        prompt: EngineInput,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        # 提取参数
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        length_penalty = params.length_penalty
        # 获取 tokenizer 并创建排序函数
        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
​
        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")
​
        # 解码器 prompt 提取
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_token_ids = decoder_prompt["prompt_token_ids"]
​
        # 初始化采样参数,请求 2 * beam_width 个 logprobs
        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=params.temperature,
        )
​
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []
​
        # 循环生成 token,注意:当 logprobs 为 None 时可能存在索引错误(review 指出)
        for _ in range(max_tokens):
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"
            for i, beam in enumerate(all_beams):
                # 对每个 beam 发起异步生成请求
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            beam.get_prompt(),
                            sampling_params,
                            f"{request_id_batch}-beam-{i}",
                            lora_request=beam.lora_request,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)
            output = [x[0] for x in await asyncio.gather(*tasks)]
            # ... 后续处理省略
vllm/entrypoints/openai/engine/serving.py dependency-wiring

被修改的关键入口文件,通过继承 BeamSearchOnlineMixin 集成在线 beam search。

# 主要变化:继承关系和导入调整
# 旧:class OpenAIServing:
# 新:class OpenAIServing(BeamSearchOnlineMixin):# 新增 import
from vllm.entrypoints.generate.beam_search.online import BeamSearchOnlineMixin# 删除了 beam_search 方法及其依赖的 import (asyncio, numpy, vllm.beam_search 等 )
# 通过继承 mixin 自动获得 beam_search 功能

评论区精华

索引错误风险:当 logprobs 为 None 时索引偏移 正确性

gemini-code-assist[bot] 指出 online.py 中若 result.outputs[0].logprobs 为 None,后续索引会错误映射到错误的父 beam。

结论:未在 review 中看到作者修复确认,合并时可能仍存在此问题。 · unaddressed

性能优化:sum(..., []) 展平效率低 性能

gemini-code-assist[bot] 建议使用 itertools.chain.from_iterable 替代 sum(..., []) 以避免 O(N^2)。

结论:未看到作者采纳或拒绝,建议未落地。 · unaddressed

Mixin 对外部方法的依赖设计讨论 设计

DarkLight1337 对 mixin 依赖 _preprocess_cmpl 等外部方法表示疑虑,作者承诺先放置代码后续优化。

结论:双方达成理解,暂时接受现有设计。 · 已解决

风险与影响

重构本身安全,但 review 中指出的索引 bug(online.py 中 logprobs 为 None 时的索引错位)若未修复则会导致 beam 搜索产生错误结果。此外,离线实现中 sum(..., []) 的 O(N^2) 展平操作在 beams 数量大时可能引入性能下降。当前 PR 合并时未确认是否已修复这些建议,需后续跟进。

对用户:LLM.beam_search 和 OpenAIServing 的 beam search 接口保持不变,行为无变化。对系统:减少了重复代码,提高了模块化,但引入了新的 mixin 依赖关系。对团队:后续扩展 beam search 只需修改 mixin 类,无需改动两个入口类。

索引逻辑易错 未修复的 review 建议 外部依赖耦合

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论