执行摘要
- 一句话:将 beam search 逻辑提取到独立的 Mixin 类中,统一离线与在线入口。
- 推荐动作:该 PR 是清晰的重构,值得阅读以理解 vLLM 入口点的 mixin 设计模式。但需注意 review 中未解决的索引 bug 和性能问题,建议在合并前或后续提交中修复。
功能与动机
根据 #42267,将 beam search 的实现从 LLM 和 OpenAIServing 类中提取为独立的 mixin,以提高代码模块化和复用性,便于后续维护和扩展。
实现拆解
- 创建目录结构
vllm/entrypoints/generate/beam_search/,添加 __init__.py 和 utils.py(从 vllm/beam_search.py 重命名)。
- 提取离线 beam search 逻辑到
BeamSearchOfflineMixin,位于 offline.py,该方法原为 LLM 类方法,现在通过多继承注入。
- 提取在线 beam search 逻辑到
BeamSearchOnlineMixin,位于 online.py,该方法原为 OpenAIServing 类方法,现在通过继承使用。
- 修改
vllm/entrypoints/llm.py:删除内联 beam_search 方法及其导入,改为导入并继承 BeamSearchOfflineMixin。
- 修改
vllm/entrypoints/openai/engine/serving.py:删除内联 beam_search 方法及相关导入,改为导入并继承 BeamSearchOnlineMixin,并调整 super().__init__() 调用。
- 更新 CI 配置:
.buildkite/test-amd.yaml 和 .buildkite/test_areas/samplers.yaml 更新测试路径以匹配新文件结构。
关键文件:
vllm/entrypoints/generate/beam_search/offline.py(模块 入口层;类别 source;类型 dependency-wiring;符号 BeamSearchOfflineMixin, beam_search, _preprocess_cmpl, _lora_request_to_seq): 新增文件,包含离线 beam search 核心逻辑 BeamSearchOfflineMixin,是重构的核心。
vllm/entrypoints/generate/beam_search/online.py(模块 入口层;类别 source;类型 dependency-wiring;符号 BeamSearchOnlineMixin, beam_search): 新增文件,包含在线 beam search 核心逻辑 BeamSearchOnlineMixin,是重构的核心。
vllm/entrypoints/openai/engine/serving.py(模块 入口层;类别 source;类型 dependency-wiring;符号 OpenAIServing, beam_search): 被修改的关键入口文件,通过继承 BeamSearchOnlineMixin 集成在线 beam search。
vllm/entrypoints/llm.py(模块 入口层;类别 source;类型 dependency-wiring;符号 LLM, beam_search): 被修改的关键入口文件,通过继承 BeamSearchOfflineMixin 集成离线 beam search。
vllm/entrypoints/generate/beam_search/utils.py(模块 入口层;类别 source;类型 rename-or-move;符号 BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, create_sort_beams_key_function): 从 vllm/beam_search.py 重命名,包含 BeamSearchSequence 等工具类,是 mixin 的依赖。
vllm/entrypoints/generate/__init__.py(模块 入口层;类别 source;类型 core-logic): 新增空包文件,标识 generate 子包。
vllm/entrypoints/generate/beam_search/__init__.py(模块 入口层;类别 source;类型 core-logic): 新增空包文件,标识 beam_search 子包。
.buildkite/test-amd.yaml(模块 CI配置;类别 config;类型 configuration): 更新测试路径以匹配 beam_search 文件移动。
.buildkite/test_areas/samplers.yaml(模块 CI配置;类别 config;类型 configuration): 更新 samplers 测试区域的 beam search 测试路径。
关键符号:BeamSearchOfflineMixin.beam_search, BeamSearchOnlineMixin.beam_search, LLM.init, OpenAIServing.init
关键源码片段
vllm/entrypoints/generate/beam_search/offline.py
新增文件,包含离线 beam search 核心逻辑 BeamSearchOfflineMixin,是重构的核心。
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from typing import Any
from tqdm import tqdm
from vllm import PromptType, RequestOutput, TextPrompt, TokensPrompt
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.sampling_params import BeamSearchParams, SamplingParams
from .utils import (
BeamSearchInstance,
BeamSearchOutput,
BeamSearchSequence,
create_sort_beams_key_function,
)
logger = init_logger(__name__)
class BeamSearchOfflineMixin(ABC):
"""Offline inference for beam search"""
renderer: BaseRenderer
def beam_search(
self,
prompts: list[TokensPrompt | TextPrompt],
params: BeamSearchParams,
lora_request: list[LoRARequest] | LoRARequest | None = None,
use_tqdm: bool = False,
concurrency_limit: int | None = None,
) -> list[BeamSearchOutput]:
# 从 params 提取配置
beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
length_penalty = params.length_penalty
# 获取 tokenizer 并创建排序函数
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
# 预处理 prompt 并生成 lora 请求序列
engine_inputs = self._preprocess_cmpl(prompts)
lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))
# 配置并发限制
if concurrency_limit is None:
concurrency_limit = len(engine_inputs)
# 每个步骤生成 2 * beam_width 候选,参考 HuggingFace Transformers
sampling_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
skip_clone=True, # 内部 beam search,安全跳过克隆
)
# 初始化 beam 实例
instances: list[BeamSearchInstance] = []
for lora_req, prompt in zip(lora_requests, engine_inputs):
if prompt["type"] == "embeds":
raise NotImplementedError("Embedding prompt not supported for beam search")
instances.append(
BeamSearchInstance(prompt, lora_request=lora_req, logprobs=None),
)
# 按批次处理 prompt
for prompt_start in range(0, len(instances), concurrency_limit):
# ... 核心循环省略,包括展平所有 beam、生成、评分和选择 top beam
pass
vllm/entrypoints/generate/beam_search/online.py
新增文件,包含在线 beam search 核心逻辑 BeamSearchOnlineMixin,是重构的核心。
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from abc import ABC
from collections.abc import AsyncGenerator, Mapping
import numpy as np
from vllm import CompletionOutput, RequestOutput
from vllm.engine.protocol import EngineClient
from vllm.inputs import EngineInput
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.utils import random_uuid
from vllm.utils.async_utils import collect_from_async_generator
from .utils import BeamSearchSequence, create_sort_beams_key_function
class BeamSearchOnlineMixin(ABC):
"""online serving for beam search"""
renderer: BaseRenderer
engine_client: EngineClient
async def beam_search(
self,
prompt: EngineInput,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> AsyncGenerator[RequestOutput, None]:
# 提取参数
beam_width = params.beam_width
max_tokens = params.max_tokens
length_penalty = params.length_penalty
# 获取 tokenizer 并创建排序函数
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
if prompt["type"] == "embeds":
raise NotImplementedError("Embedding prompt not supported for beam search")
# 解码器 prompt 提取
decoder_prompt = (
prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
)
prompt_token_ids = decoder_prompt["prompt_token_ids"]
# 初始化采样参数,请求 2 * beam_width 个 logprobs
logprobs_num = 2 * beam_width
sampling_params = SamplingParams(
logprobs=logprobs_num,
max_tokens=1,
temperature=params.temperature,
)
all_beams = [
BeamSearchSequence(
orig_prompt=prompt,
tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[],
lora_request=lora_request,
)
]
completed = []
# 循环生成 token,注意:当 logprobs 为 None 时可能存在索引错误(review 指出)
for _ in range(max_tokens):
tasks = []
request_id_batch = f"{request_id}-{random_uuid()}"
for i, beam in enumerate(all_beams):
# 对每个 beam 发起异步生成请求
task = asyncio.create_task(
collect_from_async_generator(
self.engine_client.generate(
beam.get_prompt(),
sampling_params,
f"{request_id_batch}-beam-{i}",
lora_request=beam.lora_request,
trace_headers=trace_headers,
)
)
)
tasks.append(task)
output = [x[0] for x in await asyncio.gather(*tasks)]
# ... 后续处理省略
vllm/entrypoints/openai/engine/serving.py
被修改的关键入口文件,通过继承 BeamSearchOnlineMixin 集成在线 beam search。
# 主要变化:继承关系和导入调整
# 旧:class OpenAIServing:
# 新:class OpenAIServing(BeamSearchOnlineMixin):
# 新增 import
from vllm.entrypoints.generate.beam_search.online import BeamSearchOnlineMixin
# 删除了 beam_search 方法及其依赖的 import (asyncio, numpy, vllm.beam_search 等 )
# 通过继承 mixin 自动获得 beam_search 功能
评论区精华
风险与影响
- 风险:重构本身安全,但 review 中指出的索引 bug(online.py 中 logprobs 为 None 时的索引错位)若未修复则会导致 beam 搜索产生错误结果。此外,离线实现中 sum(..., []) 的 O(N^2) 展平操作在 beams 数量大时可能引入性能下降。当前 PR 合并时未确认是否已修复这些建议,需后续跟进。
- 影响:对用户:
LLM.beam_search 和 OpenAIServing 的 beam search 接口保持不变,行为无变化。对系统:减少了重复代码,提高了模块化,但引入了新的 mixin 依赖关系。对团队:后续扩展 beam search 只需修改 mixin 类,无需改动两个入口类。
- 风险标记:索引逻辑易错, 未修复的 review 建议, 外部依赖耦合
关联脉络
参与讨论