执行摘要
- 一句话:添加Triton内核加速NVFP4反量化和QDQ模拟
- 推荐动作:值得精读:
- 学习 Triton 内核优化技巧:二进制树 E2M1 查找、2D tile 批处理、interleave 合并写。
- 理解设备间功能兼容性处理:通过
current_platform.is_cuda_alike() 动态切换实现。
- 关注社区反馈中对类型安全的关注,建议合并后进一步放宽
global_scale 类型以支持 float。
功能与动机
NVFP4 量化模拟在 AMD MI3xx、H100 等不支持原生 FP4 的设备上需要反量化和 QDQ 操作。原始的 PyTorch 参考实现因发射大量独立内核导致 HSA 崩溃(见 issue #180341、#5071)。本 PR 通过提供融合的 Triton 内核来显著加速模拟并规避内核堆积问题。
实现拆解
-
在 nvfp4_emulation_utils.py 中添加 Triton JIT 内核:
_e2m1_inline:使用二进制树分解实现内联 E2M1 查找,避免 LUT 导致的 LDS bank 冲突。
_dequantize_nvfp4_kernel:处理反量化,支持 2D/3D 输入,使用 tile 批处理和 interleave 存储以合并写操作。
_nvfp4_quant_dequant_kernel:融合量化-反量化通道,匹配 Python cast_to_fp4 的阈值。
_triton_dequantize_nvfp4 和 _triton_nvfp4_quant_dequant:调度函数,根据输入动态选择内核参数。
-
修改 dequantize_to_dtype 入口:条件判断 current_platform.is_cuda_alike(),在 CUDA/ROCm 上自动切换到 Triton 内核,否则回退到原始 PyTorch 实现。同时调整 global_scale 类型注解为 torch.Tensor(之前允许 float),以配合内核调用约定。
-
调整 MoE 模拟路径:
- 在
fused_moe/oracle/nvfp4.py 中,将 E2M1 查找表提前搬运至设备(kE2M1ToFloat_handle.val.to(...)),避免 CUDA graph 捕获时调用 .to(device)。
- 在
fused_moe/utils.py 中,ref_nvfp4_quant_dequant 的返回值改为先解包再返回 (A, None),以兼容下游预期。
-
新增测试文件 tests/kernels/quantization/test_nvfp4_emulation.py:
- 使用真实 NVFP4 权重(
nvidia/Qwen3-30B-A3B-NVFP4)验证 2D/3D 多种形状的正确性(bitwise identical)。
- 通过 monkeypatch 禁用 Triton 路径来对比 PyTorch 参考实现。
- 内嵌性能基准测试,量化加速比。
关键文件:
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py(模块 量化工具;类别 source;类型 data-contract;符号 _e2m1_inline, _dequantize_nvfp4_kernel, _e2m1_lookup, _round_to_fp4): 最主要变更文件,新增 6 个 Triton JIT 内核函数,修改 dequantize_to_dtype 入口以调度 Triton 路径,并引入 import 依赖。
tests/kernels/quantization/test_nvfp4_emulation.py(模块 测试;类别 test;类型 test-coverage;符号 test_triton_dequantize_nvfp4, _triton_bench, _reference_bench, test_triton_nvfp4_quant_dequant): 新增测试文件,覆盖 2D/3D 多种形状和真实模型权重的正确性与性能基准,验证 Triton 内核与参考实现 bitwise 一致。
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py(模块 MoE调度;类别 source;类型 data-contract): MoE NVFP4 后端配置:在 EMULATION 分支提前将 E2M1 查找表搬到设备,避免 CUDA graph 捕获时触发 .to(device) 导致错误。
vllm/model_executor/layers/fused_moe/utils.py(模块 MoE工具;类别 source;类型 data-contract): 修改 moe_kernel_quantize_input 中 emulation 路径的返回值处理:将 ref_nvfp4_quant_dequant 的返回值先赋值再返回 (A, None),以匹配其他分支的返回格式。
关键符号:_e2m1_inline, _dequantize_nvfp4_kernel, _e2m1_lookup, _round_to_fp4, _nvfp4_quant_dequant_kernel, _triton_nvfp4_quant_dequant, _triton_dequantize_nvfp4, test_triton_dequantize_nvfp4, test_triton_nvfp4_quant_dequant
关键源码片段
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
最主要变更文件,新增 6 个 Triton JIT 内核函数,修改 dequantize_to_dtype 入口以调度 Triton 路径,并引入 import 依赖。
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import tl, triton
# ... 其他导入和常量 ...
@triton.jit
def _e2m1_inline(magnitude):
"""Inline E2M1 lookup using binary tree - 3 levels instead of 7 sequential.
Maps 3-bit magnitude to float: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
"""
# Bit 2 (MSB): separates 0-3 from 4-7
# Bit 1: separates within groups
# Bit 0 (LSB): separates within pairs
b2 = (magnitude >> 2) & 1
b1 = (magnitude >> 1) & 1
b0 = magnitude & 1
# For mag 0-3: [0.0, 0.5, 1.0, 1.5]
low_group = tl.where(
b1 == 1, tl.where(b0 == 1, 1.5, 1.0), tl.where(b0 == 1, 0.5, 0.0))
# For mag 4-7: [2.0, 3.0, 4.0, 6.0]
high_group = tl.where(
b1 == 1, tl.where(b0 == 1, 6.0, 4.0), tl.where(b0 == 1, 3.0, 2.0))
return tl.where(b2 == 1, high_group, low_group)
@triton.jit
def _dequantize_nvfp4_kernel(
fp4_ptr, scale_ptr, global_scale_ptr, output_ptr,
rows_per_batch: tl.constexpr, num_blocks: tl.constexpr,
BLOCK_SIZE: tl.constexpr, has_batch_global_scale: tl.constexpr,
TILE_BLOCKS: tl.constexpr,
):
"""Triton kernel for NVFP4 dequantization (swizzle=False).
Uses 2D tile processing + interleave for coalesced stores.
"""
BLOCK_PACKED: tl.constexpr = BLOCK_SIZE // 2
row_idx = tl.program_id(0)
tile_idx = tl.program_id(1)
if has_batch_global_scale:
batch_idx = row_idx // rows_per_batch
global_scale = tl.load(global_scale_ptr + batch_idx).to(tl.float32)
else:
global_scale = tl.load(global_scale_ptr).to(tl.float32)
# ... ( 计算偏移,加载 scale,加载 packed bytes)
# 提取低 4 位和高 4 位
low_nibble = raw_bytes & 0x0F
high_nibble = (raw_bytes >> 4) & 0x0F
# 使用二进制树解码 E2M1
low_mag = low_nibble & 0x07
low_val = _e2m1_inline(low_mag)
low_sign = (low_nibble >> 3) & 1
low_result = tl.where(low_sign == 1, -low_val, low_val) * scale_values
high_mag = high_nibble & 0x07
high_val = _e2m1_inline(high_mag)
high_sign = (high_nibble >> 3) & 1
high_result = tl.where(high_sign == 1, -high_val, high_val) * scale_values
# 交织存储以合并写操作
result = tl.interleave(low_result, high_result)
tl.store(output_ptr + out_indices, result, mask=block_mask[:, None])
# ... 其他内核和调度函数 ...
# 在 dequantize_to_dtype 入口处添加调度逻辑:
def dequantize_to_dtype(
tensor_fp4: torch.Tensor,
tensor_sf: torch.Tensor,
global_scale: torch.Tensor, # 类型从 Tensor | float 改为此
dtype: torch.dtype,
block_size: int = 16,
swizzle: bool | None = True,
):
# ... 前置处理 ...
if current_platform.is_cuda_alike() and not swizzle:
# 使用 Triton 内核加速
return _triton_dequantize_nvfp4(
tensor_fp4, tensor_sf, global_scale, dtype, block_size)
# 否则回退到原有 PyTorch 实现
# ...
tests/kernels/quantization/test_nvfp4_emulation.py
新增测试文件,覆盖 2D/3D 多种形状和真实模型权重的正确性与性能基准,验证 Triton 内核与参考实现 bitwise 一致。
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import huggingface_hub
import pytest
import torch
from safetensors import safe_open
from vllm.model_executor.layers.quantization.utils import nvfp4_emulation_utils
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
dequantize_to_dtype, ref_nvfp4_quant_dequant)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
# 跳过非 CUDA 平台
@pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="Triton NVFP4 kernel requires CUDA.")
def test_triton_dequantize_nvfp4(monkeypatch):
"""Test Triton dequant kernel against CPU reference using real weights."""
# 从 HuggingFace Hub 下载真实权重
checkpoint_path = huggingface_hub.snapshot_download(
"nvidia/Qwen3-30B-A3B-NVFP4",
allow_patterns=["model-00001-of-00004.safetensors"])
shard_path = f"{checkpoint_path}/model-00001-of-00004.safetensors"
block_size = 16
with safe_open(shard_path, framework="pt", device="cpu") as f:
# 2D 张量:注意力投影
tensor_fp4_2d = f.get_tensor("model.layers.9.self_attn.k_proj.weight")
tensor_sf_2d = f.get_tensor("model.layers.9.self_attn.k_proj.weight_scale")
global_scale_2d = f.get_tensor("model.layers.9.self_attn.k_proj.weight_scale_2")
# 3D 张量:堆叠所有专家的 up_proj
# ... ( 省略收集所有 expert 权重的代码 ) ...
# 准备测试用例列表,涵盖 2D/3D 的不同 size
test_cases = [
("2D base", tensor_fp4_2d, tensor_sf_2d, global_scale_2d),
("3D base", tensor_fp4_3d, tensor_sf_3d, global_scale_3d),
# ... 更多形状 ...
]
# 将 E2M1 LUT 搬到 CUDA(模拟模型加载时的行为)
nvfp4_emulation_utils.kE2M1ToFloat_handle.val = (
nvfp4_emulation_utils.kE2M1ToFloat_handle.val.cuda())
for label, fp4, sf, gs in test_cases:
fp4_cuda = fp4.cuda()
sf_cuda = sf.cuda()
gs_cuda = gs.cuda()
# Triton 路径
triton_result = dequantize_to_dtype(
fp4_cuda, sf_cuda, gs_cuda, torch.bfloat16, block_size,
swizzle=False)
# 参考路径(通过 monkeypatch 禁用 CUDA 分支来强制 PyTorch 实现)
with monkeypatch.context() as m:
m.setattr(
nvfp4_emulation_utils.current_platform, "is_cuda_alike",
lambda: False)
reference = dequantize_to_dtype(
fp4_cuda, sf_cuda, gs_cuda, torch.bfloat16, block_size,
swizzle=False)
# 验证 bitwise 一致
torch.testing.assert_close(triton_result, reference, atol=0, rtol=0)
# 性能基准部分(由 _triton_bench 和 _reference_bench 辅助)
# ...
评论区精华
风险与影响
关联脉络
- PR #35733 相关的 NVFP4 模拟基础设施 PR(未在历史列表中,但 PR body 提及): 本 PR 建立在 #35733 的基础上,提供 Triton 内核加速。
- PR #35737 相关的 NVFP4 模拟 PR(未在历史列表中,但 PR body 提及): 本 PR 独立于 #35737 提交,但密切相关,属于同一功能线。
- PR #180341 ROCm 下大量小内核导致段错误: 本 PR 通过融合内核解决了该 issue 中的崩溃问题。
- PR #5071 ROCm 下 AsyncEventsLoop 段错误: 本 PR 规避了该 issue 中描述的 HSA 崩溃。
参与讨论