Prhub

#44368 [ROCm][CI] Fix stale wvSplitK GEMM fallback test for N=5

原始 PR 作者 JartX 合并时间 2026-06-03 11:00 文件变更 1 提交数 1 评论 0 代码增减 +4 / -2

执行摘要

修复 ROCm wvSplitK GEMM 回退测试的边界值

PR #40687 改变了 wvSplitK 的触发边界,使得原有回退测试失效(断言回退但实际未回退)。此 PR 确保测试与代码逻辑保持一致,正确验证边界条件。

值得合并。虽然变更量小,但确保了测试与代码逻辑的一致性,避免了 CI 的虚假失败。

讨论亮点

无。仅 reviewer AndreasKaratzas 批准并评论 'Nice catch'。

实现拆解

  1. test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back 重命名为 test_rocm_unquantized_gemm_gfx1x_n_gt_5_falls_back
  2. 将输入张量 x 的第一个维度从 5 改为 6,使得 n=6 > 5,确保回退路径被触发。
  3. 添加注释说明 wvSplitK skinny-GEMM 处理 n 在 [1,5] 范围,n>5 必须回退。
文件 模块 状态 重要度
tests/model_executor/layers/test_rocm_unquantized_gemm.py ROCm 测试 modified 5.3

关键符号

test_rocm_unquantized_gemm_gfx1x_n_gt_5_falls_back

关键源码片段

tests/model_executor/layers/test_rocm_unquantized_gemm.py test-coverage

修改了回退测试用例的输入维度和名称,以匹配 wvSplitK 边界条件变更。

# 位于 tests/model_executor/layers/test_rocm_unquantized_gemm.py
# 变更前函数名 : test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back (n=5)
# 变更后函数名 : test_rocm_unquantized_gemm_gfx1x_n_gt_5_falls_back (n=6)def test_rocm_unquantized_gemm_gfx1x_n_gt_5_falls_back(monkeypatch):
    # wvSplitK skinny GEMM handles n in [1, 5] (see PR #40687); n > 5 must
    # fall back to torch.nn.functional.linear.
    x = torch.randn(6, 64, dtype=torch.float16) # 将 n 从 5 改为 6,确保超过边界
    weight = torch.randn(128, 64, dtype=torch.float16)
​
    monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
    monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
    monkeypatch.setattr(utils, "num_compute_units", lambda: 120)
​
    wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
    llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)
​
    out = utils.rocm_unquantized_gemm_impl(x, weight, None)
    ref = torch.nn.functional.linear(x, weight, None)
​
    # 断言 wvSplitK 和 LLMM1 均未被调用,确保回退到 torch 的线性操作
    wvsplitk_mock.assert_not_called()
    llmm1_mock.assert_not_called()
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

极低。仅修改测试用例的输入维度和名称,不涉及任何生产代码。

仅影响 ROCm 平台下 wvSplitK GEMM 的测试覆盖。修复后,CI 能正确验证回退边界,防止后续的回归。

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论