Prhub

#41394 [Kernel][ROCm] Native W4A16 kernel for AMD RDNA3 (gfx1100) — fp16 + bf16

原始 PR 作者 JartX 合并时间 2026-05-29 19:04 文件变更 13 提交数 48 评论 58 代码增减 +3846 / -0

执行摘要

RDNA3 原生 W4A16 内核,fp16/bf16 性能飞跃

在 RDNA3 上,唯一快速的 W4A16 路径是 ExLlama(仅 fp16),将 bf16 训练模型强制转为 fp16 会导致数值不稳定(inf/NaN),而 Triton W4A16 bf16 速度仅为 ExLlama fp16 的 1/3。本 PR 消除了这种折衷:提供原生 bf16 W4A16 路径,速度 205-345 tk/s(Triton bf16 的 2.5-4.2 倍),同时保留 fp16 路径并超越 ExLlama。

值得所有 AMD ROCm 平台开发者精读,尤其关注 C++ HIP 内核实现(标量 dot-product vs WMMA 调度、LDS 双缓冲、LLVM 编译器 bug 变通)和 Python 监听门控模式。对于量化推理优化者,该 PR 提供了一个针对消费级 GPU 极致优化的参考案例。

讨论亮点
  • 维护担忧(AndreasKaratzas):CI 没有 RDNA3 设备,回归可能无法检测,是否应该上游。JartX 回应代码隔离在专属模块且不影响其他路径;最终 tjtanaa 批准。
  • gfx1151 兼容性(gshtras):gfx1151(Strix Halo)是否应被包含?JartX 随后将内核严格门控到 gfx1100(commit 35d575a),并添加 on_gfx1151() 辅助函数供混合内核使用。
  • act-order 排列缺失(depthfirst-app):V7/V8 WMMA 内核忽略 b_q_perm,导致 act-order 模型在 prefill 阶段输出错误。JartX 在 commit 649688c 中修复,将 act-order 请求路由到 V5(该 tile 处理排列)。
  • 文件位置与编译门控(mgehre-amd):建议将内核从 csrc/libtorch_stable/ 移到 csrc/quantization/gptq/_rocm_C 扩展以解决 NVCC 兼容性。JartX 随后重构(commit 7e66513)将文件移至 csrc/rocm/,并在构建系统中正确隔离。
  • 用户反馈(wizardeur, gesong2077):在实际编码工作负载中验证了 PR 的稳定性和性能提升,社区期望合并。

实现拆解

  1. 平台门控与架构检测:在 vllm/platforms/rocm.py 中新增 on_gfx11on_gfx1100on_gfx1151 函数及模块级布尔标志(如 _ON_GFX11),通过 amdsmi 解析 GCN 架构字符串,避免 CUDA 初始化。内核的 can_implement() 利用 on_gfx1100() 精确限制在目标硬件。

  2. Python 内核封装类:新增 vllm/model_executor/kernels/linear/mixed_precision/rdna3_w4a16.py,实现 RDNA3W4A16LinearKernel 继承自 MPLinearKernel。包含 process_weights_after_loading(处理权重重排、零点点合成)、apply_weights(调用 HIP 操作)等方法,并注册到内核优先级列表(优先于 TritonW4A16LinearKernel)。

  3. C++ HIP 内核实现:在 csrc/rocm/ 下添加两个 CU 文件:

    • q_gemm_rdna3.cu:标量 dot-product 内核,用于 decode(M<16)和 fp16 路径,使用 v_dot2_f32_f16v_dot2_f32_bf16(通过 union 指针绕过 LLVM 折叠 bug)。
    • q_gemm_rdna3_wmma.cu:WMMA 矩阵乘法内核,用于 prefill(M≥16),支持 v_wmma_f32_16x16x16 指令,包含多个 tile 版本(16x16_1w 至 128x64_k32),通过 K 分割和 LDS 双缓冲优化。
      共享反量化工具头文件 qdq_4_rdna3.cuh。操作通过 TORCH_LIBRARY_EXPAND 注册到 _rocm_C 扩展。
  4. Python 操作绑定与编译集成:在 vllm/_custom_ops.py 中定义 gptq_gemm_rdna3 函数和两个 fake 操作(用于 torch.compile),并在 csrc/rocm/torch_bindings.cppops.h 中注册。内核文件通过 CMake 构建系统编译。

  5. 测试与验证:新增两个测试文件:

    • test_rdna3_w4a16.py:端到端正确性测试,对比 fp32 参考(dequant+matmul)验证 kernel 输出,覆盖 fp16/bf16、有无零点、多种形状(M=1,16,128 等)。
    • test_rdna3_w4a16_selection.py:内核选择测试,验证在 gfx1100 上 choose_mp_linear_kernel 优先返回 RDNA3 内核,以及 can_implement 正确拒绝不支持的配置。
文件 模块 状态 重要度
vllm/model_executor/kernels/linear/mixed_precision/rdna3_w4a16.py 量化内核层 added 9.36
tests/kernels/quantization/test_rdna3_w4a16.py 测试 added 8.24
vllm/_custom_ops.py 操作绑定 modified 7.84

关键符号

RDNA3W4A16LinearKernel.can_implement RDNA3W4A16LinearKernel.process_weights_after_loading RDNA3W4A16LinearKernel.apply_weights vllm._custom_ops.gptq_gemm_rdna3 vllm.platforms.rocm.on_gfx1100 vllm.platforms.rocm.on_gfx11 vllm.platforms.rocm.on_gfx1151 tests.kernels.quantization.test_rdna3_w4a16._reference tests.kernels.quantization.test_rdna3_w4a16.test_rdna3_w4a16_matches_reference tests.kernels.quantization.test_rdna3_w4a16.test_rdna3_w4a16_bias tests.kernels.quantization.test_rdna3_w4a16_selection.test_selection_prefers_rdna3 tests.kernels.quantization.test_rdna3_w4a16_selection.test_can_implement

关键源码片段

vllm/model_executor/kernels/linear/mixed_precision/rdna3_w4a16.py core-logic

核心 Python 封装类,定义 `RDNA3W4A16LinearKernel`,包含权重预处理、`apply_weights` 以及内核门控逻辑(`can_implement`)。是内核与 vLLM 推理框架的接口。

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""W4A16 GPTQ kernel for AMD RDNA3 (gfx1100) — fp16 + bf16."""import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import pack_quantized_values_into_int32
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
​
​
class RDNA3W4A16LinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8]
​
    @classmethod
    def get_min_capability(cls) -> int:
        return 60 # ROCm 门控通过 can_implement 内的 on_gfx1100() 实现
​
    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        # 仅在 ROCm 且 gfx1100 上激活
        if not current_platform.is_rocm():
            return False, "RDNA3 W4A16 kernel is ROCm-only"
        from vllm.platforms.rocm import on_gfx1100
        if not on_gfx1100():
            return False, "RDNA3 W4A16 kernel requires gfx1100"
        # 检查自定义操作是否已编译
        if not (hasattr(torch.ops, "_rocm_C") and hasattr(torch.ops._rocm_C, "gptq_gemm_rdna3")):
            return False, "torch.ops._rocm_C.gptq_gemm_rdna3 missing — rebuild C++ extension"
        if c.act_type not in (torch.float16, torch.bfloat16):
            return False, "RDNA3 W4A16 kernel only supports fp16 and bf16"
        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Quant type ({c.weight_type}) not supported"
        if c.group_size <= 0:
            return False, "does not support channelwise quantization"
        if c.full_weight_shape[0] % c.group_size != 0:
            return False, "Group size does not divide K"
        # N 必须是 8 的倍数,满足 qzeros 打包对齐
        if c.partition_weight_shape[1] % 8 != 0:
            return False, "Output features must be a multiple of 8"
        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act-order with TP-partitioned not supported"
        return True, None
​
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """权重预处理:合成零点(如果不存在)、打乱布局,与 ExLlama 布局兼容。"""
        c = self.config
        device = getattr(layer, self.w_q_name).device
        if not c.zero_points:
            self.w_zp_name = "qzeros"
            groups = c.partition_weight_shape[0] // c.group_size
            out_features = c.partition_weight_shape[1]
            if c.weight_type.has_bias():
                # GPTQv1 约定:存储零点 = bias - 1,内核运行时加 1
                zeros = torch.full((groups, out_features), c.weight_type.bias - 1,
                                   dtype=torch.int32, device=device)
            else:
                raise NotImplementedError("zero-bias 4-bit quant requires explicit zero points")
            zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1)
            setattr(layer, self.w_zp_name, zeros)
        # 调用基类处理权重重排
        super().process_weights_after_loading(layer)
​
    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """执行 W4A16 矩阵乘法。"""
        w_q = getattr(layer, self.w_q_name)
        w_s = getattr(layer, self.w_s_name)
        w_zp = getattr(layer, self.w_zp_name, None) if self.config.zero_points else None
        g_idx = getattr(layer, self.w_g_idx_name, None)
        return ops.gptq_gemm_rdna3(x, w_q, w_zp, w_s, g_idx, use_v2_format=self.config.zero_points)
tests/kernels/quantization/test_rdna3_w4a16.py test-coverage

主要正确性测试文件,覆盖 fp16/bf16、有无零点、多种 M/K/N 形状,通过 fp32 参考对比验证内核输出。

# 核心参考实现,用于对比 kernel 输出
# 正确性测试的核心:先手动 dequant,然后执行 fp32 matmuldef _reference(
    x_mk: torch.Tensor,
    q_int4_kn: torch.Tensor,
    scales_gn: torch.Tensor,
    zeros_gn: torch.Tensor | None, # None 表示对称路径,kernel 合成零点 = 7
    group_size: int,
    bias: torch.Tensor | None,
) -> torch.Tensor:
    K, N = q_int4_kn.shape
    s_full = scales_gn.repeat_interleave(group_size, dim=0).to(torch.float32) # [K, N]
    if zeros_gn is None:
        # kernel 内部:存储零点 = 7,实际零点 = 7 + 1 = 8(weight_type.bias)
        z_full = torch.full((K, N), float(WEIGHT_TYPE.bias), device=x_mk.device, dtype=torch.float32)
    else:
        # 应用 GPTQv1 +1 约定
        z_full = (zeros_gn + 1).repeat_interleave(group_size, dim=0).to(torch.float32)
    w_fp = (q_int4_kn.to(torch.float32) - z_full) * s_full # [K, N] 反量化
    out = x_mk.to(torch.float32) @ w_fp # [M, N]
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(x_mk.dtype)@gfx1100_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("M,K,N,group_size", [
    (1, 4096, 256, 128),
    (16, 4096, 256, 128),
    (128,4096, 256, 128),
    (1, 5120, 128, 128),
    (1, 11008,128, 128),
])
def test_rdna3_w4a16_matches_reference(M, K, N, group_size, has_zp, dtype):
    """端到端测试:构建 GPTQ 层 -> 预处理权重 -> kernel 推理 -> 与参考对比"""
    layer, ctx = _build_layer(M, K, N, group_size, has_zp, dtype)
    kernel = RDNA3W4A16LinearKernel(ctx)
    kernel.process_weights_after_loading(layer)
    x = torch.randn(M, K, dtype=dtype, device=device)
    out = kernel.apply_weights(layer, x)
    ref = _reference(x, ctx["q_int4_kn"], ctx["scales_gn"], ctx["zeros_gn"], group_size, bias=ctx.get("bias"))
    # 允许 fp16/bf16 精度差异,rtol 设为 1e-2, atol 1e-2
    assert torch.allclose(out, ref, rtol=1e-2, atol=1e-2), f"Mismatch for M={M}, K={K}, N={N}"
vllm/_custom_ops.py core-logic

注册新的 ROCm 操作 `gptq_gemm_rdna3` 及其 fake 实现,是 Python 与 C++ 内核的桥梁。

# 在 vllm/_custom_ops.py 中(原有 gptq_gemm 函数之后)def gptq_gemm_rdna3(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_qzeros: torch.Tensor,
    b_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_v2_format: bool,
) -> torch.Tensor:
    """封装 _rocm_C 操作,支持 torch.compile 动态 shape。"""
    return torch.ops._rocm_C.gptq_gemm_rdna3(
        a, b_q_weight, b_qzeros, b_scales, b_g_idx, use_v2_format
    )# 如果操作存在,注册 fake 实现(用于 torch.compile shape 推导)
if hasattr(torch.ops, "_rocm_C") and hasattr(torch.ops._rocm_C, "gptq_gemm_rdna3"):
​
    @register_fake("_rocm_C::gptq_gemm_rdna3")
    def _gptq_gemm_rdna3_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_qzeros: torch.Tensor,
        b_scales: torch.Tensor,
        b_g_idx: torch.Tensor,
        use_v2_format: bool,
    ) -> torch.Tensor:
        return torch.empty(
            (a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device
        )

评论区精华

维护责任与 CI 覆盖 设计

AndreasKaratzas 担心 CI 没有 RDNA3 设备,导致回归无法检测,质疑是否应上游。tjtanaa 最终批准,认为与 CDNA 路径完全隔离,不影响 AMD CI。

结论:代码隔离在专属模块,且仅 gfx1100 激活;团队接受维护风险。 · 已解决

gfx1151 (Strix Halo) 兼容性 设计

gshtras 指出 gfx115x 不是 Radeon 但属于 gfx11,是否应被内核覆盖。JartX 随后将内核严格门控到 gfx1100,并添加 on_gfx1151() 供混合内核使用。

结论:内核仅限 gfx1100;gfx1151 留给混合内核(PR #40977)。 · 已解决

WMMA 内核忽略 act-order 排列 正确性

depthfirst-app 发现 V7/V8 WMMA 内核始终顺序读取激活,忽略 b_q_perm,导致 act-order 模型 prefill 输出错误。

结论:JartX 提交修复(commit 649688c):将 act-order 请求路由到 V5 tile(该 tile 正确应用排列)。 · 已解决

文件位置与编译门控 设计

mgehre-amd 指出内核文件应放在 csrc/rocm/ 并加入 _rocm_C 扩展以避免 NVCC 编译失败(使用 AMD 内置函数)。JartX 随后执行重构(commit 7e66513)。

结论:文件移至 csrc/rocm/ 并正确门控,确保 NVCC 下安全存根。 · 已解决

标量内核 bf16 M=1 跳过 act-order 正确性

depthfirst-app 指出 M=1 bf16 路径使用 `USE_LDS_A=false`,未应用 b_q_perm,导致 decode 阶段 act-order 模型输出错误。

结论:JartX 在后续 commits 中修复:为 bf16 M=1 路径恢复 LDS staging(或通过其他方式应用排列)。最终提交 05f48afa 明确应用 act-order 排列。 · 已解决

使用 constexpr 替换宏 style

gemini-code-assist 建议使用 constexpr 常量替代 #define 宏以提高类型安全性。

结论:未被采纳(可能因 HIP 内核风格或兼容性原因),评论已解决。 · 已解决

风险与影响

  • 回归风险:内核仅在 gfx1100 上激活,但 CI 缺少 RDNA3 设备,任何回归(如编译失败、性能下降)在合并前可能无法被自动检测。
  • 正确性风险:早期版本存在 act-order 排列未应用和 bf16 M=1 路径直接读取全局内存的问题,虽已在最终提交中修复,但若未来代码重构导致类似逻辑退化,可能重新引入。此外,C++ 操作缺少对 N%8 的 TORCH_CHECK(依赖 Python 层),存在潜在的内存越界。
  • 编译风险:内核使用 ROCm 内置函数(如 __builtin_amdgcn_fdot2、WMMA 内置),需确保在非 ROCm 编译环境下安全存根。已通过 _rocm_C 扩展和 #if defined(USE_ROCM) 门控。
  • 性能回退:多版本 WMMA 内核(V1-V8)的分派逻辑复杂,可能在某些边角情况(如 K 余数处理)触发次优路径或原子 CAS 竞争。测试已覆盖常见形状,但不排除发布后出现性能退化。

对 AMD RDNA3 (gfx1100) 用户有巨大正面影响:W4A16 推理延迟降低 2.5-4.2 倍,支持 bf16 保留模型原生精度;fp16 路径同样超越 ExLlama。对其他 AMD GPU(CDNA、RDNA4)无影响——自动回退到 Triton 内核。团队需承担维护责任,但代码隔离在 rdna3_w4a16.pycsrc/rocm/ 下,不干扰核心路径。新增 ~3800 行代码,其中 2200 行为 C++ 内核,测试覆盖充分。

缺少 RDNA3 CI 覆盖 act-order 边角情形已修复 编译门控依赖特定架构 无回归自动检测机制

关联 Issue

#6074 [Issue]: RCCL 2.27.7 / ROCm 7.2.1 dual-GPU collectives fail on 2x Radeon RX 7900 XTX

完整报告

参与讨论