Prhub

#43818 [Misc] added unit tests for the core pooling methods

原始 PR 作者 taneem-ibrahim 合并时间 2026-05-29 22:40 文件变更 1 提交数 7 评论 5 代码增减 +499 / -0

执行摘要

为核心池化方法添加单元测试

PR 指出这些池化方法(CLSPool、LastPool、MeanPool、AllPool、StepPool)的单元测试覆盖率为零,且对池化子系统至关重要,因此添加测试以确保正确性并防止回归。

值得精读,尤其是 _make_pooling_cursor_make_metadata 辅助函数的设计——它们封装了复杂的 PoolingMetadata 构造逻辑,使测试代码简洁且易于扩展。同时,测试中对 partial prefill 和 chunked prefill 的边界处理也值得参考。

讨论亮点

Reviewer yewentao256 提出三点意见:

  • 前两条指出部分测试用例(如 test_extracts_first_token)已覆盖核心逻辑,额外的一些测试显得冗余,建议精简。
  • 第三条要求增加 enable_chunked_prefill=True 的测试场景,以确保池化方法在分块预填充下行为正确。
    作者积极响应,在后续提交中移除了冗余测试,并专门添加了分块预填充的测试覆盖。最终审核通过并合并。

实现拆解

  1. 编写测试辅助函数:在 tests/model_executor/layers/test_pooler_methods.py 中新增 _make_pooling_cursor_make_metadata,用于简化 PoolingCursorPoolingMetadata 的构造,支持自定义 prompt_lensnum_scheduled_tokensseq_lens 等参数,为后续测试提供基础设施。
  2. 序列级池化方法测试:为 CLSPoolLastPoolMeanPool 编写测试类,验证正向提取逻辑(如 extract first/last/mean token)以及异常场景(如 partial prefill 时抛出 AssertionError)。
  3. 词元级池化方法测试:为 AllPoolStepPool 编写测试类,验证多 token 输出和步进池化行为。
  4. 工厂函数测试:测试 get_seq_pooling_methodget_tok_pooling_method 的正确性及默认行为。
  5. 边界场景补充:根据 review 建议,增加 enable_chunked_prefill=True 的测试用例,覆盖调度 token 数小于实际 prompt 长度的场景。
  6. Activations 测试:在同一个文件中还包含了池化激活函数的测试,确保 PoolerActivation 的字符串到函数映射正确。
文件 模块 状态 重要度
tests/model_executor/layers/test_pooler_methods.py 池化层 added 7.74

关键符号

_make_pooling_cursor _make_metadata test_extracts_first_token test_rejects_partial_prefill test_extracts_last_token test_partial_prefill_extracts_last_scheduled

关键源码片段

tests/model_executor/layers/test_pooler_methods.py test-coverage

新增 499 行测试文件,全面覆盖池化方法和工厂函数。

# SPDX-License-Identifier: Apache-2.0import torch
import pytest
from vllm.model_executor.layers.pooler.seqwise.methods import CLSPool
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata, PoolingStates_CPU = torch.device("cpu")
​
​
def _make_pooling_cursor(
    prompt_lens: list[int],
    *,
    num_scheduled_tokens: list[int] | None = None,
    seq_lens: list[int] | None = None,
    device: torch.device = _CPU,
) -> PoolingCursor:
    """从每个序列的提示长度构建 PoolingCursor。"""
    # 确保输入为 PyTorch tensor,便于后续批处理
    prompt_lens_cpu = torch.tensor(prompt_lens, dtype=torch.long)
    if num_scheduled_tokens is None:
        num_scheduled_tokens_cpu = prompt_lens_cpu.clone()
    else:
        num_scheduled_tokens_cpu = torch.tensor(num_scheduled_tokens, dtype=torch.long)
    if seq_lens is None:
        seq_lens_cpu = prompt_lens_cpu.clone()
    else:
        seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.long)
​
    # 使用累积和计算每个序列的 token 起始 / 结束索引
    cumsum = torch.zeros(len(prompt_lens) + 1, dtype=torch.long, device=device)
    torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
​
    return PoolingCursor(
        first_token_indices_gpu=cumsum[: len(prompt_lens)].to(device),
        last_token_indices_gpu=(cumsum[1:] - 1).to(device),
        prompt_lens_cpu=prompt_lens_cpu,
        seq_lens_cpu=seq_lens_cpu,
        num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
    )
​
​
class TestCLSPool:
    """CLSPool 从每个序列提取第一个 token 作为池化输出。"""
​
    def test_extracts_first_token(self):
        # 构造 2 个序列,prompt_lens 分别为 2 和 3,共 5 个 hidden states
        hidden = torch.tensor(
            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
        )
        metadata = _make_metadata([2, 3])
        pooler = CLSPool()
        out = pooler(hidden, metadata)
        # 预期第一序列取 hidden[0],第二序列取 hidden[2]
        expected = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        assert torch.equal(out, expected)
​
    def test_rejects_partial_prefill(self):
        # 当 num_scheduled_tokens < prompt_lens 时(partial prefill),
        # CLSPool 应断言报错,因为第一个 token 可能尚未调度
        hidden = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        metadata = _make_metadata([3], num_scheduled_tokens=[2])
        pooler = CLSPool()
        with pytest.raises(AssertionError, match="partial prefill"):
            pooler(hidden, metadata)

评论区精华

冗余测试用例建议 测试

yewentao256 指出 test_extracts_first_token 已覆盖核心逻辑,额外的测试似乎冗余。

结论:作者移除了冗余测试,从后续提交推断已解决。 · 已解决

第二处冗余建议 测试

yewentao256 在另一位置指出类似冗余。

结论:同样被作者采纳处理。 · 已解决

chunked prefill 测试建议 测试

yewentao256 建议增加 enable_chunked_prefill=True 的测试场景。

结论:作者在后续提交中添加了 chunked prefill 测试覆盖。 · 已解决

风险与影响

该 PR 为纯测试变更,不修改任何生产代码,因此回归风险极低。主要风险在于测试覆盖可能仍不完整(如未测试所有可能的参数量组合),但已经覆盖了主要的正向和边界场景,能有效防止后续重构或优化引入的回归。

对用户无直接影响。对团队而言,显著提升池化子系统的可测试性,为未来修改(如添加新池化方法或优化实现)提供安全网,降低引入 bug 的概率。测试代码也是一份可执行的文档,帮助新开发者理解池化行为。

纯测试变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论