Prhub

#40687 [ROCm][Perf] Support N=5 in wvSplitK skinny GEMM kernels for speculative decoding

原始 PR 作者 mgehre-amd 合并时间 2026-05-29 00:28 文件变更 2 提交数 2 评论 3 代码增减 +7 / -1

执行摘要

ROCm 瘦 GEMM 内核支持 N=5,加速推测解码验证

在 AMD ROCm 平台上使用推测解码(如 Eagle3, num_speculative_tokens=4)时,目标模型验证阶段的 batch size 为 5(1 个原始 token + 4 个推测 token)。wvSplitK 瘦 GEMM 内核原本仅支持 N<=4,导致验证步骤 fallback 到 torch.nn.functional.linear(hipBLAS),性能不佳。PR 旨在通过扩展内核支持,使验证步骤也能使用高性能 HIP 内核,从而提升推测解码的整体性能。

值得合并的针对性性能优化。建议未来考虑自动化特化更多 N 值的方法,以减少手动添加 case 的工作量和编译时间。同时可关注 custom op 的优化机会。

讨论亮点

审核者 tjtanaa 询问此改动是否仅针对 N=5 有提升,以及能否进一步增加 N 的范围(如 N<16)。作者回应称,增加 N 的特殊化会显著增加编译时间(因涉及其他模板参数),目前未验证更高 N 的性能。另外 tjtanaa 提到 torch.nn.linear 被包装在 custom op 中,未受 torch.compile 优化,需要考虑 custom ops 的优化。

实现拆解

  1. 修改 dispatch 阈值:在 vllm/model_executor/layers/utils.pyrocm_unquantized_gemm_impl 函数中,将 wvSplitK 内核的启用条件从 0 < n <= 4 改为 0 < n <= 5,使 batch size 为 5 的验证操作也能使用该内核。
  2. 添加内核 tile 配置:在 csrc/rocm/skinny_gemms.cuwvSplitK 函数 switch 语句中添加 case 5: 分支,根据是否使用 wave32 模式设置相应的 tile 配置(32x16 或 64x16),确保内核能正确处理 N=5 的矩阵乘法。
文件 模块 状态 重要度
vllm/model_executor/layers/utils.py GEMM 调度器 modified 5.5
csrc/rocm/skinny_gemms.cu HIP 内核 modified 3.81

关键符号

rocm_unquantized_gemm_impl wvSplitK

关键源码片段

vllm/model_executor/layers/utils.py data-contract

修改了 wvSplitK 内核的 dispatch 条件,将 N 上限从 4 提升到 5,是 PR 性能提升的关键决策点。

# vllm/model_executor/layers/utils.py ( 修改 dispatch 条件 )
def rocm_unquantized_gemm_impl(...):
    # ... 前面的条件判断 ...
​
    use_skinny = (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and (on_gfx9() or on_gfx1x())
        and x.dtype in [torch.float16, torch.bfloat16]
        and k % 8 == 0
    )
​
    if use_skinny:
        x_view = x.reshape(-1, x.size(-1))
        # 原条件 0 < n <= 4,现改为 0 < n <= 5
        # 使得 Eagle3 验证(batch size 5)也能使用 wvSplitK 内核
        if m > 8 and 0 < n <= 5:
            cu_count = num_compute_units()
            out = ops.wvSplitK(weight, x_view, cu_count, bias)
            return out.reshape(*x.shape[:-1], weight.shape[0])
        elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
            out = ops.LLMM1(weight, x_view, 4)
            return out.reshape(*x.shape[:-1], weight.shape[0])
​
    # fallback 到其他内核或 torch.nn.functional.linear
csrc/rocm/skinny_gemms.cu core-logic

添加了 N=5 的 tile 配置,使内核能处理新的 batch size。

// csrc/rocm/skinny_gemms.cu (wvSplitK 函数中新增 case 5)
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ...) {
    // ... 前面的 switch 处理 N=1..4 ...
    case 5:
        // 新增 N=5 的 tile 配置
        // use_wave32 为 true 时使用 32x16 tile
        // 否则使用 64x16 tile
        if (use_wave32)
            WVSPLIT_TILE_CFG(32, 16, sYT, 5)
        else
            WVSPLIT_TILE_CFG(64, 16, sYT, 5)
        break;
    default:
        throw std::runtime_error(
            "Unsupported N value: " + std::to_string(M_in) + "," + ...);
    // ... 后续处理 ...
}

评论区精华

是否可进一步扩展 N 支持范围 question

tjtanaa 询问能否证明更多场景的提升,或是否可增加 N<16 的支持。作者回应增加 N 特化会显著增加编译时间,尚未验证更高 N。

结论:当前仅支持到 N=5,未来如需更大 N 需权衡编译时间与收益。 · 已解决

风险与影响

风险较低。仅修改了 dispatch 条件(n<=5)和添加了一个 switch case,不影响已有功能。主要风险在于 kernel 编译时间可能因新增 specialization 而略有增加,但已在讨论中确认可接受。此外,由于未对 N>5 的场景进行优化,若未来有更大 batch 的验证需求,仍需 fallback。

直接影响 ROCm 平台上使用推测解码(Eagle3 等)的 Qwen3-8B 等模型,验证阶段 batch size 为 5 时可获得 12-14% 的性能提升。对其他模型或非推测解码场景无影响。改动仅 7 行代码,无配置或接口变更,无兼容性问题。

仅限 ROCm 平台 增加编译时间

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论