Prhub

#40033 [NVFP4][Hopper/AMD Instinct] Add Triton kernels for NVFP4 dequantization and QDQ emulation

原始 PR 作者 fxmarty-amd 合并时间 2026-05-01 05:35 文件变更 4 提交数 10 评论 16 代码增减 +633 / -4

执行摘要

添加 Triton 内核加速 NVFP4 反量化和 QDQ 模拟

NVFP4 量化模拟在 AMD MI3xx、H100 等不支持原生 FP4 的设备上需要反量化和 QDQ 操作。原始的 PyTorch 参考实现因发射大量独立内核导致 HSA 崩溃(见 issue #180341、#5071)。本 PR 通过提供融合的 Triton 内核来显著加速模拟并规避内核堆积问题。

值得精读:

  • 学习 Triton 内核优化技巧:二进制树 E2M1 查找、2D tile 批处理、interleave 合并写。
  • 理解设备间功能兼容性处理:通过 current_platform.is_cuda_alike() 动态切换实现。
  • 关注社区反馈中对类型安全的关注,建议合并后进一步放宽 global_scale 类型以支持 float。
讨论亮点
  • BowenBao 询问端到端性能:作者提供了 MI300X 和 MI355X 上 Qwen/DeepSeek 模型的完整吞吐数据,显示 QDQ + Activation Skipped 模式性能最佳,Triton 融合内核相比 unfused FP32 有约 6-8 倍加速。
  • gemini-code-assist 指出的类型安全风险global_scale 类型从 torch.Tensor | float 收紧为 torch.Tensor,可能破坏传入 float 的调用者。作者回应“此改动从 #35737 移植而来,是安全的”,但未在 PR 中进一步放宽。
  • BowenBao 质疑 ref_nvfp4_quant_dequant 返回 tuple 的必要性:作者随后在 commit 662fa67 中调整了调用处,使其不再直接返回 tuple,但函数签名仍保留 tuple 返回类型。讨论已解决。
  • kylesayrs 对查找表设备迁移模式的评论:认为在 EMULATION backend 中提前搬运 kE2M1ToFloat_handle 的模式虽然奇怪,但为避免 CUDA graph 限制,这是最佳折中。作者计划后续 PR 中改进。
  • kylesayrs 建议添加 block_size 整除性断言:作者在 7c885b37 中采纳,添加了 x_k % block_size 断言。

实现拆解

  1. nvfp4_emulation_utils.py 中添加 Triton JIT 内核

    • _e2m1_inline:使用二进制树分解实现内联 E2M1 查找,避免 LUT 导致的 LDS bank 冲突。
    • _dequantize_nvfp4_kernel:处理反量化,支持 2D/3D 输入,使用 tile 批处理和 interleave 存储以合并写操作。
    • _nvfp4_quant_dequant_kernel:融合量化-反量化通道,匹配 Python cast_to_fp4 的阈值。
    • _triton_dequantize_nvfp4_triton_nvfp4_quant_dequant:调度函数,根据输入动态选择内核参数。
  2. 修改 dequantize_to_dtype 入口:条件判断 current_platform.is_cuda_alike(),在 CUDA/ROCm 上自动切换到 Triton 内核,否则回退到原始 PyTorch 实现。同时调整 global_scale 类型注解为 torch.Tensor(之前允许 float),以配合内核调用约定。

  3. 调整 MoE 模拟路径

    • fused_moe/oracle/nvfp4.py 中,将 E2M1 查找表提前搬运至设备(kE2M1ToFloat_handle.val.to(...)),避免 CUDA graph 捕获时调用 .to(device)
    • fused_moe/utils.py 中,ref_nvfp4_quant_dequant 的返回值改为先解包再返回 (A, None),以兼容下游预期。
  4. 新增测试文件 tests/kernels/quantization/test_nvfp4_emulation.py

    • 使用真实 NVFP4 权重(nvidia/Qwen3-30B-A3B-NVFP4)验证 2D/3D 多种形状的正确性(bitwise identical)。
    • 通过 monkeypatch 禁用 Triton 路径来对比 PyTorch 参考实现。
    • 内嵌性能基准测试,量化加速比。
文件 模块 状态 重要度
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py 量化工具 modified 9.05
tests/kernels/quantization/test_nvfp4_emulation.py 测试 added 7.76
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py MoE 调度 modified 5.61
vllm/model_executor/layers/fused_moe/utils.py MoE 工具 modified 4.63

关键符号

_e2m1_inline _dequantize_nvfp4_kernel _e2m1_lookup _round_to_fp4 _nvfp4_quant_dequant_kernel _triton_nvfp4_quant_dequant _triton_dequantize_nvfp4 test_triton_dequantize_nvfp4 test_triton_nvfp4_quant_dequant

关键源码片段

vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py data-contract

最主要变更文件,新增 6 个 Triton JIT 内核函数,修改 dequantize_to_dtype 入口以调度 Triton 路径,并引入 import 依赖。

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM projectimport torch
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import tl, triton# ... 其他导入和常量 ...@triton.jit
def _e2m1_inline(magnitude):
    """Inline E2M1 lookup using binary tree - 3 levels instead of 7 sequential.
    Maps 3-bit magnitude to float: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
    """
    # Bit 2 (MSB): separates 0-3 from 4-7
    # Bit 1: separates within groups
    # Bit 0 (LSB): separates within pairs
    b2 = (magnitude >> 2) & 1
    b1 = (magnitude >> 1) & 1
    b0 = magnitude & 1
​
    # For mag 0-3: [0.0, 0.5, 1.0, 1.5]
    low_group = tl.where(
        b1 == 1, tl.where(b0 == 1, 1.5, 1.0), tl.where(b0 == 1, 0.5, 0.0))
    # For mag 4-7: [2.0, 3.0, 4.0, 6.0]
    high_group = tl.where(
        b1 == 1, tl.where(b0 == 1, 6.0, 4.0), tl.where(b0 == 1, 3.0, 2.0))
    return tl.where(b2 == 1, high_group, low_group)
​
​
@triton.jit
def _dequantize_nvfp4_kernel(
    fp4_ptr, scale_ptr, global_scale_ptr, output_ptr,
    rows_per_batch: tl.constexpr, num_blocks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr, has_batch_global_scale: tl.constexpr,
    TILE_BLOCKS: tl.constexpr,
):
    """Triton kernel for NVFP4 dequantization (swizzle=False).
    Uses 2D tile processing + interleave for coalesced stores.
    """
    BLOCK_PACKED: tl.constexpr = BLOCK_SIZE // 2
    row_idx = tl.program_id(0)
    tile_idx = tl.program_id(1)
​
    if has_batch_global_scale:
        batch_idx = row_idx // rows_per_batch
        global_scale = tl.load(global_scale_ptr + batch_idx).to(tl.float32)
    else:
        global_scale = tl.load(global_scale_ptr).to(tl.float32)
​
    # ... ( 计算偏移,加载 scale,加载 packed bytes)
​
    # 提取低 4 位和高 4 位
    low_nibble = raw_bytes & 0x0F
    high_nibble = (raw_bytes >> 4) & 0x0F
​
    # 使用二进制树解码 E2M1
    low_mag = low_nibble & 0x07
    low_val = _e2m1_inline(low_mag)
    low_sign = (low_nibble >> 3) & 1
    low_result = tl.where(low_sign == 1, -low_val, low_val) * scale_values
​
    high_mag = high_nibble & 0x07
    high_val = _e2m1_inline(high_mag)
    high_sign = (high_nibble >> 3) & 1
    high_result = tl.where(high_sign == 1, -high_val, high_val) * scale_values
​
    # 交织存储以合并写操作
    result = tl.interleave(low_result, high_result)
    tl.store(output_ptr + out_indices, result, mask=block_mask[:, None])
​
​
# ... 其他内核和调度函数 ...# 在 dequantize_to_dtype 入口处添加调度逻辑:
def dequantize_to_dtype(
    tensor_fp4: torch.Tensor,
    tensor_sf: torch.Tensor,
    global_scale: torch.Tensor, # 类型从 Tensor | float 改为此
    dtype: torch.dtype,
    block_size: int = 16,
    swizzle: bool | None = True,
):
    # ... 前置处理 ...
    if current_platform.is_cuda_alike() and not swizzle:
        # 使用 Triton 内核加速
        return _triton_dequantize_nvfp4(
            tensor_fp4, tensor_sf, global_scale, dtype, block_size)
    # 否则回退到原有 PyTorch 实现
    # ...
tests/kernels/quantization/test_nvfp4_emulation.py test-coverage

新增测试文件,覆盖 2D/3D 多种形状和真实模型权重的正确性与性能基准,验证 Triton 内核与参考实现 bitwise 一致。

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM projectimport huggingface_hub
import pytest
import torch
from safetensors import safe_openfrom vllm.model_executor.layers.quantization.utils import nvfp4_emulation_utils
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    dequantize_to_dtype, ref_nvfp4_quant_dequant)
from vllm.platforms import current_platform
from vllm.triton_utils import triton# 跳过非 CUDA 平台
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(),
    reason="Triton NVFP4 kernel requires CUDA.")
def test_triton_dequantize_nvfp4(monkeypatch):
    """Test Triton dequant kernel against CPU reference using real weights."""
    # 从 HuggingFace Hub 下载真实权重
    checkpoint_path = huggingface_hub.snapshot_download(
        "nvidia/Qwen3-30B-A3B-NVFP4",
        allow_patterns=["model-00001-of-00004.safetensors"])
    shard_path = f"{checkpoint_path}/model-00001-of-00004.safetensors"
    block_size = 16
​
    with safe_open(shard_path, framework="pt", device="cpu") as f:
        # 2D 张量:注意力投影
        tensor_fp4_2d = f.get_tensor("model.layers.9.self_attn.k_proj.weight")
        tensor_sf_2d = f.get_tensor("model.layers.9.self_attn.k_proj.weight_scale")
        global_scale_2d = f.get_tensor("model.layers.9.self_attn.k_proj.weight_scale_2")
​
        # 3D 张量:堆叠所有专家的 up_proj
        # ... ( 省略收集所有 expert 权重的代码 ) ...
​
    # 准备测试用例列表,涵盖 2D/3D 的不同 size
    test_cases = [
        ("2D base", tensor_fp4_2d, tensor_sf_2d, global_scale_2d),
        ("3D base", tensor_fp4_3d, tensor_sf_3d, global_scale_3d),
        # ... 更多形状 ...
    ]
​
    # 将 E2M1 LUT 搬到 CUDA(模拟模型加载时的行为)
    nvfp4_emulation_utils.kE2M1ToFloat_handle.val = (
        nvfp4_emulation_utils.kE2M1ToFloat_handle.val.cuda())
​
    for label, fp4, sf, gs in test_cases:
        fp4_cuda = fp4.cuda()
        sf_cuda = sf.cuda()
        gs_cuda = gs.cuda()
​
        # Triton 路径
        triton_result = dequantize_to_dtype(
            fp4_cuda, sf_cuda, gs_cuda, torch.bfloat16, block_size,
            swizzle=False)
​
        # 参考路径(通过 monkeypatch 禁用 CUDA 分支来强制 PyTorch 实现)
        with monkeypatch.context() as m:
            m.setattr(
                nvfp4_emulation_utils.current_platform, "is_cuda_alike",
                lambda: False)
            reference = dequantize_to_dtype(
                fp4_cuda, sf_cuda, gs_cuda, torch.bfloat16, block_size,
                swizzle=False)
​
        # 验证 bitwise 一致
        torch.testing.assert_close(triton_result, reference, atol=0, rtol=0)
​
    # 性能基准部分(由 _triton_bench 和 _reference_bench 辅助)
    # ...

评论区精华

类型安全性:global_scale 从支持 float 改为仅 Tensor 正确性

gemini-code-assist 指出类型收紧可能破坏传入 float 的调用者,且 3D 分支逻辑对 scalar tensor 可能不兼容。

结论:作者回应此改动从 #35737 移植而来,是安全的。PR 中未进一步放宽类型,但内部调用均为 Tensor,风险可控。 · 已解决

返回 tuple 的不必要性 设计

BowenBao 询问 `ref_nvfp4_quant_dequant` 为何返回 `tuple[torch.Tensor, None]`。

结论:作者在 commit 662fa67 中调整了调用处,不再直接返回 tuple;函数签名保留 tuple 以保持一致性。 · 已解决

E2M1 查找表设备迁移模式 设计

kylesayrs 认为在 EMULATION 分支提前搬迁 `kE2M1ToFloat_handle` 到设备是一种奇怪的模式,但理解这是避免 CUDA graph 限制的必要折中。

结论:作者同意当前模式最佳,计划后续 PR 中改进。讨论未要求立即更改。 · 已解决

建议添加 block_size 整除性断言 正确性

kylesayrs 建议在 `_triton_nvfp4_quant_dequant` 中添加 `x_k % block_size` 断言,确保形状兼容。

结论:作者在 7c885b37 中采纳并添加了该断言(仅在非编译时)。 · 已解决

未使用的函数参数 style

kylesayrs 指出某个参数未使用。

结论:作者在 7c885b37 中清理了未使用的参数(如 `e2m1_lut_ptr`)。 · 已解决

风险与影响

  1. 类型缩小风险nvfp4_emulation_utils.py):global_scale 参数类型从 torch.Tensor | float 改为 torch.Tensor,若外部调用者传入 float 将会触发 TypeError。当前代码库内部调用均使用 Tensor,但第三方扩展可能受影响。
  2. 边界条件风险_dequantize_nvfp4_kernel):内核假设 BLOCK_SIZE 整除张量维度,添加了断言但仅限非编译时;对于动态形状或非标准 block_size(未测试 32 等)可能出错。
  3. HSA 崩溃规避的附带风险:融合内核虽解决了大量小内核导致的崩溃,但若其他路径仍发射大量独立内核(如非 emulation 模式),问题依然存在。本 PR 仅覆盖 NVFP4 emulation 路径。
  4. CUDA graph 兼容性:在 oracle/nvfp4.py 中提前搬运查找表是必要的,但这种跨后端耦合可能引入未来重构时的遗漏。

用户影响:使用 NVFP4 量化模拟的用户(AMD MI3xx、H100)将获得 10-58 倍的反量化/QDQ 操作加速,端到端吞吐提升 6-8 倍。同时,因大量内核堆积导致的 HSA 崩溃不再发生。
系统影响:仅修改 NVFP4 emulation 路径,不涉及默认推理路径。对于使用本机 NVFP4 支持(如 H100 原生 FP4)的用户无影响。
团队影响:维护两个实现(PyTorch 参考 + Triton 内核),需要确保一致性。测试覆盖了关键形状和真实模型权重,但未必覆盖所有边缘形状。

类型提示变更 新内核边界条件 HSA 崩溃规避的局限性 CUDA graph 兼容性模式

关联 Issue

#180341 [rocm only] Segfault in AsyncEventsLoop when queuing many small kernels without synchronization
#5071 [Issue]: Segfault in AsyncEventsLoop when queuing many small kernels without synchronization ("fixed" with `HSA_TOOLS_DISABLE_REGISTER=1`)

完整报告

参与讨论