执行摘要
- 一句话:跳过 B200 上 DeepGEMM UE8M0 路径的 RMS+quant 融合测试
- 推荐动作:建议合并,因为这是临时性的测试跳过,且文档清晰地指出了根本原因和修复方向。审阅者应关注后续是否有人跟进实现真正的融合修复(可追踪 TODO 和 draft PR #40650)。
功能与动机
修复 B200 上 torch vllm 联合测试中的 8 个失败用例(test_fusion_rmsnorm_quant),根本原因是 QuantFP8 在 B200 上走 packed UE8M0 路径(per_token_group_fp8_quant_packed,int32-packed scales),但 rms+quant 融合模式只匹配 fp32-scale 变体,导致断言失败。
实现拆解
- 新增 import:在
tests/compile/passes/test_fusion.py 中导入 is_deep_gemm_e8m0_used 工具函数。
- 添加跳过条件:在
test_fusion_rmsnorm_quant 测试函数中,当 dtype 为 bf16、kernel 为 DeepGemmFp8BlockScaledMMKernel 或 FlashInferFp8DeepGEMMDynamicBlockScaledKernel、且 is_deep_gemm_e8m0_used() 返回 True 时,跳过该测试用例,并附上详细 TODO 说明。
- 补充 block_size 属性:在
tests/utils.py 的 TestFp8Linear 类初始化中,为 block-wise 路径添加 self.weight_block_size = [block_size, block_size],确保测试辅助类与真实模型行为一致。
- 移除原始融合修复代码:第一个提交曾尝试实现在
rms_norm_per_block_quant 中处理 UE8M0 scale,但 review 讨论后决定暂不引入此修复,仅跳过测试。
关键文件:
tests/compile/passes/test_fusion.py(模块 编译融合;类别 test;类型 test-coverage): 主要变更文件,添加了跳过条件逻辑和详细 TODO 注释,解释为什么需要跳过及后续修复方向。
tests/utils.py(模块 测试工具;类别 test;类型 test-coverage): 补充了 block-wise 量化所需的 weight_block_size 属性,确保测试辅助类与真实模型行为一致。
关键符号:test_fusion_rmsnorm_quant, TestFp8Linear.init
关键源码片段
tests/compile/passes/test_fusion.py
主要变更文件,添加了跳过条件逻辑和详细 TODO 注释,解释为什么需要跳过及后续修复方向。
# 在测试函数中添加跳过条件:当使用 DeepGEMM UE8M0 路径时跳过
# TODO(quant-rms-fusion): DeepGEMM UE8M0 activation quant on B200 lowers
# to a packed int32-scale op (per_token_group_quant_fp8_packed_for_deepgemm),
# but the rms+quant fusion pattern only matches the fp32-scale variant, so
# the fused output gets a mismatched scale layout and produces NaN. Only
# reproduces on bf16 (DeepGEMM UE8M0 on B200 is bf16-only).
# To re-enable: make rms_norm_per_block_quant emit packed UE8M0 scales
# and extend the fusion pattern to rewrite the packed activation quant.
deepgemm_kernels = (
DeepGemmFp8BlockScaledMMKernel,
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
)
if (
dtype == torch.bfloat16
and force_kernel in deepgemm_kernels
and is_deep_gemm_e8m0_used()
):
pytest.skip(
"rms+quant fusion does not yet match the packed UE8M0 DeepGEMM path"
)
tests/utils.py
补充了 block-wise 量化所需的 weight_block_size 属性,确保测试辅助类与真实模型行为一致。
# 在 block-wise 分支中增加 weight_block_size 属性
if is_block_wise:
block_size = weight_scale_desc.group_shape.col
weight_scale_shape = weight_shape[0] // block_size
self.weight_scale_inv = torch.rand(
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
)
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
self.input_scale = None
self.weight_scale = None
self.weight_block_size = [block_size, block_size] # 新增:记录 block size
if transpose_weights:
self.weight = self.weight.t()
评论区精华
- ElizaWszola 询问性能影响,建议 benchmark 对比 fused 和 packed 版本。提交者 Lucaskabela 提供了详细的 micro-benchmark 数据(见表),显示 fused 版本在大多数配置下快 10-22%,但在大 batch 下 packed 更快。
- ProExpertProg 同意跳过测试的方案,但要求将原始修复代码保存为 draft PR 以供后续参考。
-
gemini-code-assist[bot] 建议将 layernorm_utils.cuh 中的 scale 调整逻辑提取为共享 helper 函数,并用命名常量替代魔数 1e-10f。
-
性能影响评估 (performance): Lucaskabela 提供了 micro-benchmark 表,显示 fused 在多数配置下更快(10-22%),但在大 batch(3072 tokens, 7168 hidden)下 packed 快 18%。同意跳过测试。
- 代码重复与可维护性 (style): 未直接处理,因为该修复代码已被移除,仅保留测试跳过。
- 后续修复计划 (design): 已创建 draft PR #40650,后续可在此基础上完善。
风险与影响
- 风险:本 PR 仅跳过测试,未修改生产代码,风险极低。但需要注意:跳过测试意味着 B200 上 DeepGEMM UE8M0 路径的 rms+quant 融合未经验证,可能存在隐藏的正确性问题。此外,
tests/utils.py 中添加 weight_block_size 属性可能影响其他依赖此类的测试,但该属性仅为新增字段,不会破坏现有逻辑。
- 影响:用户:无直接影响,因为这是测试级别的变更。系统:B200 上相关测试不会再因预期失败而中断 CI。团队:需在后续 PR 中实现真正的融合修复(见 TODO),避免长期跳过测试导致回归遗漏。影响程度低,范围仅限于特定硬件(B200)的特定测试。
- 风险标记:测试覆盖跳过, 后续需修复
关联脉络
- PR #40650 Draft: Fix RMS norm + quant fusion on DeepGEMM UE8M0 path for B200: 同一功能的真正修复草案,本 PR 仅跳过测试,该 draft PR 包含实际修复代码。
参与讨论