执行摘要
- 一句话:移除ROCm不支持的FP8测试用例并修复normalize条件
- 推荐动作:此PR为维护性清理,不值得精读。但可以关注ROCm FP8 AITER的支持边界以及fp8_utils中normalize条件的改进思路。
功能与动机
ROCm上AITER仅支持group量化,不支持per-tensor量化融合;另外test_fuse_act_padding存在已知精度问题(见ROCm/aiter#2614)。
实现拆解
- 在
vllm/model_executor/layers/quantization/utils/fp8_utils.py中的三个处理函数(process_fp8_weight_tensor_strategy、process_fp8_weight_channel_strategy、process_fp8_weight_block_strategy)的条件判断中增加weight.dtype == torch.float8_e4m3fn检查,确保只有在权重类型为float8_e4m3fn时才调用normalize_e4m3fn_to_e4m3fnuz,避免重复或错误转换。
- 在
tests/compile/passes/test_fuse_act_padding.py中,将@pytest.mark.skipif替换为@pytest.mark.skip,并关联上游issue,永久跳过该测试;同时移除is_aiter_found_and_supported导入和outputs_unfused = model(x)行。
- 在
tests/compile/passes/test_fusion.py中,从AITER_KERNEL_GROUPSHAPE_COMBINATIONS列表中移除(ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False)条目,因为AITER不支持per-tensor量化融合。
关键文件:
vllm/model_executor/layers/quantization/utils/fp8_utils.py(模块 量化工具;类别 source;类型 data-contract;符号 process_fp8_weight_tensor_strategy, process_fp8_weight_channel_strategy, process_fp8_weight_block_strategy): 核心源码变更,在三个FP8 weight处理函数中添加了类型检查条件,防止不必要的normalize操作。
tests/compile/passes/test_fuse_act_padding.py(模块 融合填充测试;类别 test;类型 test-coverage): 由于已知精度问题,整个测试被永久跳过;同时清理了相关导入和未使用代码。
tests/compile/passes/test_fusion.py(模块 FP8融合测试;类别 test;类型 test-coverage): 移除AITER不支持的per-tensor量化融合测试组合
关键符号:process_fp8_weight_tensor_strategy, process_fp8_weight_channel_strategy, process_fp8_weight_block_strategy
关键源码片段
vllm/model_executor/layers/quantization/utils/fp8_utils.py
核心源码变更,在三个FP8 weight处理函数中添加了类型检查条件,防止不必要的normalize操作。
def process_fp8_weight_tensor_strategy(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: list[int],
input_scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Process weights for tensor-wise quantization strategy."""
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale,
)
# 仅当平台是 FP8 fnuz 且权重类型为 float8_e4m3fn 时,
# 才执行 normalize 操作,避免重复转换。
if current_platform.is_fp8_fnuz() and weight.dtype == torch.float8_e4m3fn:
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale, input_scale=input_scale
)
# Requantize with max scale
weight_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=weight_scale,
logical_widths=logical_widths,
)
weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale, input_scale
评论区精华
review中无实质争议。维护者yewentao256与AndreasKaratzas均批准更改。yewentao256表示问题已在主分支修复。charlifu在评论中说明因精度问题跳过了test_fuse_act_padding,并引用了上游issue。
风险与影响
- 风险:主要风险是测试覆盖减少:test_fuse_act_padding被完全跳过,可能导致未来相关融合pass的回归未被捕获。此外,fp8_utils.py的条件增强虽降低了错误转换风险,但可能掩藏其他类型权重的问题。整体风险低。
- 影响:影响范围限定于ROCm平台的FP8量化功能。测试变更使AITER测试更准确反映实际支持情况;源码修改避免了不必要的类型转换,对CUDA平台无影响。团队需注意测试覆盖率下降,并跟踪上游issue的修复进展。
- 风险标记:测试覆盖减少, 上游issue依赖
关联脉络
参与讨论