Prhub

#41813 [CPU][Zen] Route W8A8 and W4A16 linear inference through zentorch on AMD Zen CPUs

原始 PR 作者 aadwived 合并时间 2026-05-31 03:17 文件变更 7 提交数 20 评论 50 代码增减 +351 / -4

执行摘要

AMD Zen CPU 上 zentorch 加速 W8A8/W4A16 线性层

AMD Zen CPU 用户运行量化模型时,通用 oneDNN 核无法利用 zentorch 定制算子的性能优势。此 PR 从 RFC #41629 出发,为 vLLM 在 Zen CPU 上的部署提供通道级优化,同时通过 fallback 保证兼容性。

值得精读,尤其是 kernel 选择器 fallback 设计、平台检测函数抽象以及量化权重兼容性检查。建议在后续 PR 中考虑引入 PlatformEnum.ZEN 并增加端到端集成测试。

讨论亮点

W4A16 偏移量争议:Gemini Code Assist 认为 unpack 后 +8 导致数据损坏,贡献者解释 unpack_from_int32 输出 int4 范围 [-8,7],而 zentorch 期望 uint4 [0,15],偏移为必要转换,最终保留。
TorchAO 类型检查问题:Gemini 指出 isinstance(layer.weight, Int8Tensor) 因 Parameter 包装始终 False,贡献者测试确认实际路径正确;TorchAO 集成随后移出此 PR。
独立后端重构:mgoin 建议创建独立 zentorch.py 避免污染 cpu.py,团队采纳并实施。
测试条件统一:AndreasKaratzas 建议合并两个 skipif 装饰器,贡献者合并;SQNR 阈值从 30 dB 提升至 40 dB,参考 torchao 标准。
PlatformEnum 讨论:Andreas 建议引入 PlatformEnum.ZEN,团队同意作为后续 PR。

实现拆解

  1. 创建独立 zentorch 后端文件:在 scaled_mm/zentorch.pymixed_precision/zentorch.py 中分别新增 ZentorchInt8ScaledMMLinearKernelZentorchWNA16LinearKernel,继承自对应的 CPU 基类,覆写 is_supportedcan_implementprocess_weights_after_loadingapply_weights
  2. 实现平台和 op 可用性检查:通过 current_platform.is_zen_cpu()has_zentorch_op 函数(定义于 zentorch_utils.py)确保仅在 AMD Zen CPU 且 zentorch 库已安装时激活。
  3. 实现权重重排ZentorchWNA16LinearKernel_zentorch_woq_eligible 中检验压缩张量格式并排除带 g_idx 的层,然后通过 process_weights_after_loading 调用 torch.ops.zentorch.zentorch_woq_repack_weight 将 GPTQ 权重重排为 zentorch 布局;同时清理未使用的 zero_point 和 g_idx 属性。
  4. 集成到 kernel 选择器:在 linear/__init__.py 中导入新 kernel,置于 CPU 平台列表队首,确保优先选择;在 _LINEAR_BACKEND_KERNEL_MAP 中添加条目;在 register_linear_kernel 的公共 API 中添加名称。
  5. 更新依赖和清理:在 setup.py 中将 zentorch 依赖固定为 zentorch==5.2.1,移除过时的 weekly 注释。
文件 模块 状态 重要度
vllm/model_executor/kernels/linear/mixed_precision/zentorch.py W4A16 核 added 9.28
vllm/model_executor/kernels/linear/scaled_mm/zentorch.py W8A8 核 added 8.85
vllm/model_executor/kernels/linear/zentorch_utils.py 工具函数 added 7.03
vllm/model_executor/kernels/linear/__init__.py 核注册 modified 5.75
vllm/model_executor/kernels/linear/mixed_precision/__init__.py 核注册 modified 4.48
vllm/model_executor/kernels/linear/scaled_mm/__init__.py 核注册 modified 4.48
setup.py 构建配置 modified 4.67

关键符号

has_zentorch_op ZentorchInt8ScaledMMLinearKernel.is_supported ZentorchInt8ScaledMMLinearKernel.can_implement ZentorchInt8ScaledMMLinearKernel.process_weights_after_loading ZentorchInt8ScaledMMLinearKernel.apply_weights ZentorchWNA16LinearKernel.can_implement ZentorchWNA16LinearKernel._zentorch_woq_eligible ZentorchWNA16LinearKernel.process_weights_after_loading ZentorchWNA16LinearKernel.apply_weights

关键源码片段

vllm/model_executor/kernels/linear/mixed_precision/zentorch.py core-logic

新增的 zentorch W4A16 量化线性核后端,包含资格检查、权重重排和推理调用,是此 PR 的核心实现之一。

class ZentorchWNA16LinearKernel(CPUWNA16LinearKernel):
    """W4A16 GPTQ kernel backed by ``torch.ops.zentorch.zentorch_woq_linear``."""
​
    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        ok, reason = super().can_implement(c)
        if not ok:
            return ok, reason
        # 检查平台:仅 AMD Zen CPU
        if not current_platform.is_zen_cpu():
            return False, "ZentorchWNA16 requires an AMD Zen CPU."
        # 检查 zentorch op 注册
        if not has_zentorch_op(["zentorch_woq_repack_weight", "zentorch_woq_linear"]):
            return (False, "torch.ops.zentorch.{...} are not registered.")
        # 不支持激活重排(g_idx)
        if c.has_g_idx:
            return False, "ZentorchWNA16 does not support activation re-ordering."
        return True, None
​
    def _zentorch_woq_eligible(self, layer: torch.nn.Module) -> bool:
        # 排除带 g_idx 的层
        if (self.w_gidx_name is not None
            and getattr(layer, self.w_gidx_name, None) is not None) \
           or getattr(self.config, "has_g_idx", False):
            return False
        # 必须存在权重和尺度
        weight_packed = getattr(layer, self.w_q_name, None)
        weight_scale = getattr(layer, self.w_s_name, None)
        if weight_packed is None or weight_scale is None:
            return False
        bits = self.config.weight_type.mantissa
        pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
        # 仅支持 4-bit 打包(每个 int32 存 8 个值)
        if pack_factor != 8:
            return False
        # 仅支持 compressed-tensors 格式(input_dim == packed_dim == 1)
        in_dim = getattr(weight_packed, "input_dim", None)
        pk_dim = getattr(weight_packed, "packed_dim", None)
        if not (in_dim == pk_dim == 1):
            return False
        if weight_packed.dim() != 2 or weight_scale.dim() != 2:
            return False
        in_features = weight_packed.shape[1] * 8
        num_groups = weight_scale.shape[1]
        return num_groups > 0 and in_features % num_groups == 0
vllm/model_executor/kernels/linear/scaled_mm/zentorch.py core-logic

新增的 zentorch W8A8 动态 symmetric 量化线性核后端,包含 is_supported/can_implement 检查、权重格式转换和推理调用。

class ZentorchInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
    @classmethod
    def is_supported(cls, compute_capability=None) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
            return False, "requires CPU."
        if not current_platform.is_zen_cpu():
            return False, "requires AMD Zen CPU."
        if not has_zentorch_op(["zentorch_dynamic_qlinear"]):
            return False, "torch.ops.zentorch.zentorch_dynamic_qlinear is not registered."
        return True, None
​
    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if c.is_static_input_scheme:
            return False, "requires dynamic activation quantization."
        if not c.input_symmetric:
            return False, "requires symmetric activation quantization."
        if not c.is_channelwise:
            return False, "requires per-channel weight quantization."
        return True, None
​
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w_q_name, w_s_name, _, _, _ = self.layer_param_names
        weight = getattr(layer, w_q_name)
        n = weight.shape[0]
        # 权重保持 [N, K] 布局且 contiguous
        replace_parameter(layer, w_q_name,
            torch.nn.Parameter(weight.data.contiguous(), requires_grad=False))
        # 权重尺度转换为 bf16 并调整为 (N,) 形状
        weight_scale = getattr(layer, w_s_name)
        ws = weight_scale.data
        if ws.dim() == 2 and ws.shape[-1] == 1:
            ws = ws.squeeze(-1)
        ws = ws.to(torch.bfloat16).contiguous()
        assert ws.shape == (n,), f"expected ({n},), got {tuple(ws.shape)}"
        replace_parameter(layer, w_s_name,
            torch.nn.Parameter(ws, requires_grad=False))
        logger.info_once("[zen_cpu] Using zentorch_dynamic_qlinear for W8A8 (dynamic-symmetric)")
​
    def apply_weights(self, layer, x, bias=None) -> torch.Tensor:
        w_q_name, w_s_name, _, _, _ = self.layer_param_names
        return torch.ops.zentorch.zentorch_dynamic_qlinear(
            x, getattr(layer, w_q_name), getattr(layer, w_s_name),
            bias, zentorch_op_name="zentorch::zentorch_dynamic_qlinear")
vllm/model_executor/kernels/linear/zentorch_utils.py utility

提供 has_zentorch_op 函数,统一平台检测和 op 可用性检查,被所有 zentorch 后端调用。

from __future__ import annotations
import torch
from vllm.platforms import current_platform__all__ = ["has_zentorch_op"]def has_zentorch_op(op_names: list[str]) -> bool:
    """返回当运行在 Zen CPU 且所有命名 op 已注册时为 ``True``。"""
    if not op_names:
        raise ValueError("has_zentorch_op requires at least one op name")
    # 先检查平台:仅 AMD Zen CPU 才继续
    if not current_platform.is_zen_cpu():
        return False
    # 检查 torch.ops.zentorch 命名空间
    ns = getattr(torch.ops, "zentorch", None)
    if ns is None:
        return False
    # 验证所有请求的 op 都已注册
    return all(hasattr(ns, op_name) for op_name in op_names)

评论区精华

W4A16 权重偏移量 +8 的正确性 正确性

Gemini Code Assist 认为 unpack 后 +8 和 clamp 会损坏数据,建议移除。Harshaladhav 解释 unpack_from_int32 输出 int4 范围,而 zentorch 期望 uint4,偏移是必要的转换。

结论:保留偏移,贡献者验证通过。 · 已解决

TorchAO 集成中 isinstance 检查失效 正确性

Gemini 指出 isinstance(layer.weight, Int8Tensor) 对于 Parameter 始终 False。Aadwived 回应实际路径测试通过,且 TorchAO 部分已移出此 PR。

结论:TorchAO 集成移至后续 PR。 · 已解决

将 zentorch 逻辑拆分为独立后端 设计

mgoin 建议创建独立的 zentorch.py 后端,避免污染 cpu.py。Aadwived 同意并实施。

结论:已创建独立后端文件。 · 已解决

测试跳过条件和 SQNR 阈值 测试

Andreas 建议合并两个 skipif 装饰器并提高 SQNR 阈值。Aadwived 统一了 skipif,并将 SQNR 从 30 dB 提升至 40 dB,引用 torchao 标准。

结论:已更新测试。 · 已解决

未来引入 PlatformEnum.ZEN 设计

Andreas 建议长远使用独立平台枚举,Aadwived 同意但推迟至后续 PR。

结论:暂时保留在 CPU 下,后续 PR 引入 ZEN。 · 已解决

风险与影响

1) 权重布局假设风险ZentorchWNA16LinearKernel 假设 packed weight 为 compressed-tensors GPTQ 格式(input_dim == packed_dim == 1),其他格式(如 Marlin)会被 _zentorch_woq_eligible 排除并回退,无数据损坏风险。
2) zentorch 依赖风险:未安装时自动回退,但用户可能期望加速而实际未使用;缺少日志提示。
3) PerTensor 激活量化升级:当模型使用 PerTensor 时,代码静默升级为 PerRow,输出品质可能变化,但不会出错。
4) 平台枚举耦合PlatformEnum.CPU 同时包含 oneDNN 和 zentorch,未来引入 PlatformEnum.ZEN 可解耦。
5) 测试覆盖缺口:缺少端到端模型推理测试,仅单元测试覆盖 dispatch。

  • 用户影响:AMD Zen CPU 用户在运行 GPTQ(W4A16)或 LLM-Compressor(W8A8)量化模型时自动获得加速;其他平台无影响。
  • 系统影响:新增可选依赖 zentorch,非 Zen 平台无额外开销。
  • 团队影响:需维护独立后端,跟踪 zentorch 版本;kernel 选择器复杂度略有增加。
平台假设耦合 依赖条件性降级 布局假设有限 无端到端测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论