执行摘要
- 一句话:为核心池化方法添加单元测试
- 推荐动作:值得精读,尤其是
_make_pooling_cursor 和 _make_metadata 辅助函数的设计——它们封装了复杂的 PoolingMetadata 构造逻辑,使测试代码简洁且易于扩展。同时,测试中对 partial prefill 和 chunked prefill 的边界处理也值得参考。
功能与动机
PR 指出这些池化方法(CLSPool、LastPool、MeanPool、AllPool、StepPool)的单元测试覆盖率为零,且对池化子系统至关重要,因此添加测试以确保正确性并防止回归。
实现拆解
- 编写测试辅助函数:在
tests/model_executor/layers/test_pooler_methods.py 中新增 _make_pooling_cursor 和 _make_metadata,用于简化 PoolingCursor 和 PoolingMetadata 的构造,支持自定义 prompt_lens、num_scheduled_tokens、seq_lens 等参数,为后续测试提供基础设施。
- 序列级池化方法测试:为
CLSPool、LastPool、MeanPool 编写测试类,验证正向提取逻辑(如 extract first/last/mean token)以及异常场景(如 partial prefill 时抛出 AssertionError)。
- 词元级池化方法测试:为
AllPool 和 StepPool 编写测试类,验证多 token 输出和步进池化行为。
- 工厂函数测试:测试
get_seq_pooling_method 和 get_tok_pooling_method 的正确性及默认行为。
- 边界场景补充:根据 review 建议,增加
enable_chunked_prefill=True 的测试用例,覆盖调度 token 数小于实际 prompt 长度的场景。
- Activations 测试:在同一个文件中还包含了池化激活函数的测试,确保 PoolerActivation 的字符串到函数映射正确。
关键文件:
tests/model_executor/layers/test_pooler_methods.py(模块 池化层;类别 test;类型 test-coverage;符号 _make_pooling_cursor, _make_metadata, TestCLSPool, test_extracts_first_token): 新增 499 行测试文件,全面覆盖池化方法和工厂函数。
关键符号:_make_pooling_cursor, _make_metadata, test_extracts_first_token, test_rejects_partial_prefill, test_extracts_last_token, test_partial_prefill_extracts_last_scheduled
关键源码片段
tests/model_executor/layers/test_pooler_methods.py
新增 499 行测试文件,全面覆盖池化方法和工厂函数。
# SPDX-License-Identifier: Apache-2.0
import torch
import pytest
from vllm.model_executor.layers.pooler.seqwise.methods import CLSPool
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata, PoolingStates
_CPU = torch.device("cpu")
def _make_pooling_cursor(
prompt_lens: list[int],
*,
num_scheduled_tokens: list[int] | None = None,
seq_lens: list[int] | None = None,
device: torch.device = _CPU,
) -> PoolingCursor:
"""从每个序列的提示长度构建 PoolingCursor。"""
# 确保输入为 PyTorch tensor,便于后续批处理
prompt_lens_cpu = torch.tensor(prompt_lens, dtype=torch.long)
if num_scheduled_tokens is None:
num_scheduled_tokens_cpu = prompt_lens_cpu.clone()
else:
num_scheduled_tokens_cpu = torch.tensor(num_scheduled_tokens, dtype=torch.long)
if seq_lens is None:
seq_lens_cpu = prompt_lens_cpu.clone()
else:
seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.long)
# 使用累积和计算每个序列的 token 起始 / 结束索引
cumsum = torch.zeros(len(prompt_lens) + 1, dtype=torch.long, device=device)
torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
return PoolingCursor(
first_token_indices_gpu=cumsum[: len(prompt_lens)].to(device),
last_token_indices_gpu=(cumsum[1:] - 1).to(device),
prompt_lens_cpu=prompt_lens_cpu,
seq_lens_cpu=seq_lens_cpu,
num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
)
class TestCLSPool:
"""CLSPool 从每个序列提取第一个 token 作为池化输出。"""
def test_extracts_first_token(self):
# 构造 2 个序列,prompt_lens 分别为 2 和 3,共 5 个 hidden states
hidden = torch.tensor(
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
)
metadata = _make_metadata([2, 3])
pooler = CLSPool()
out = pooler(hidden, metadata)
# 预期第一序列取 hidden[0],第二序列取 hidden[2]
expected = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
assert torch.equal(out, expected)
def test_rejects_partial_prefill(self):
# 当 num_scheduled_tokens < prompt_lens 时(partial prefill),
# CLSPool 应断言报错,因为第一个 token 可能尚未调度
hidden = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
metadata = _make_metadata([3], num_scheduled_tokens=[2])
pooler = CLSPool()
with pytest.raises(AssertionError, match="partial prefill"):
pooler(hidden, metadata)
评论区精华
Reviewer yewentao256 提出三点意见:
风险与影响
- 风险:该 PR 为纯测试变更,不修改任何生产代码,因此回归风险极低。主要风险在于测试覆盖可能仍不完整(如未测试所有可能的参数量组合),但已经覆盖了主要的正向和边界场景,能有效防止后续重构或优化引入的回归。
- 影响:对用户无直接影响。对团队而言,显著提升池化子系统的可测试性,为未来修改(如添加新池化方法或优化实现)提供安全网,降低引入 bug 的概率。测试代码也是一份可执行的文档,帮助新开发者理解池化行为。
- 风险标记:纯测试变更
关联脉络
参与讨论