Prhub

#40552 [Bugfix] Fix RMS norm + quant fusion on DeepGEMM UE8M0 path for B200

原始 PR 作者 Lucaskabela 合并时间 2026-04-23 06:04 文件变更 2 提交数 2 评论 8 代码增减 +22 / -0

执行摘要

跳过 B200 上 DeepGEMM UE8M0 路径的 RMS+quant 融合测试

修复 B200 上 torch vllm 联合测试中的 8 个失败用例(test_fusion_rmsnorm_quant),根本原因是 QuantFP8 在 B200 上走 packed UE8M0 路径(per_token_group_fp8_quant_packed,int32-packed scales),但 rms+quant 融合模式只匹配 fp32-scale 变体,导致断言失败。

建议合并,因为这是临时性的测试跳过,且文档清晰地指出了根本原因和修复方向。审阅者应关注后续是否有人跟进实现真正的融合修复(可追踪 TODO 和 draft PR #40650)。

讨论亮点
  • ElizaWszola 询问性能影响,建议 benchmark 对比 fused 和 packed 版本。提交者 Lucaskabela 提供了详细的 micro-benchmark 数据(见表),显示 fused 版本在大多数配置下快 10-22%,但在大 batch 下 packed 更快。
  • ProExpertProg 同意跳过测试的方案,但要求将原始修复代码保存为 draft PR 以供后续参考。
  • gemini-code-assist[bot] 建议将 layernorm_utils.cuh 中的 scale 调整逻辑提取为共享 helper 函数,并用命名常量替代魔数 1e-10f。

实现拆解

  1. 新增 import:在 tests/compile/passes/test_fusion.py 中导入 is_deep_gemm_e8m0_used 工具函数。
  2. 添加跳过条件:在 test_fusion_rmsnorm_quant 测试函数中,当 dtype 为 bf16、kernel 为 DeepGemmFp8BlockScaledMMKernel 或 FlashInferFp8DeepGEMMDynamicBlockScaledKernel、且 is_deep_gemm_e8m0_used() 返回 True 时,跳过该测试用例,并附上详细 TODO 说明。
  3. 补充 block_size 属性:在 tests/utils.pyTestFp8Linear 类初始化中,为 block-wise 路径添加 self.weight_block_size = [block_size, block_size],确保测试辅助类与真实模型行为一致。
  4. 移除原始融合修复代码:第一个提交曾尝试实现在 rms_norm_per_block_quant 中处理 UE8M0 scale,但 review 讨论后决定暂不引入此修复,仅跳过测试。
文件 模块 状态 重要度
tests/compile/passes/test_fusion.py 编译融合 modified 5.48
tests/utils.py 测试工具 modified 3.28

关键符号

test_fusion_rmsnorm_quant TestFp8Linear.__init__

关键源码片段

tests/compile/passes/test_fusion.py test-coverage

主要变更文件,添加了跳过条件逻辑和详细 TODO 注释,解释为什么需要跳过及后续修复方向。

# 在测试函数中添加跳过条件:当使用 DeepGEMM UE8M0 路径时跳过
# TODO(quant-rms-fusion): DeepGEMM UE8M0 activation quant on B200 lowers
# to a packed int32-scale op (per_token_group_quant_fp8_packed_for_deepgemm),
# but the rms+quant fusion pattern only matches the fp32-scale variant, so
# the fused output gets a mismatched scale layout and produces NaN. Only
# reproduces on bf16 (DeepGEMM UE8M0 on B200 is bf16-only).
# To re-enable: make rms_norm_per_block_quant emit packed UE8M0 scales
# and extend the fusion pattern to rewrite the packed activation quant.
deepgemm_kernels = (
    DeepGemmFp8BlockScaledMMKernel,
    FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
)
if (
    dtype == torch.bfloat16
    and force_kernel in deepgemm_kernels
    and is_deep_gemm_e8m0_used()
):
    pytest.skip(
        "rms+quant fusion does not yet match the packed UE8M0 DeepGEMM path"
    )
tests/utils.py test-coverage

补充了 block-wise 量化所需的 weight_block_size 属性,确保测试辅助类与真实模型行为一致。

# 在 block-wise 分支中增加 weight_block_size 属性
if is_block_wise:
    block_size = weight_scale_desc.group_shape.col
    weight_scale_shape = weight_shape[0] // block_size
    self.weight_scale_inv = torch.rand(
        (weight_scale_shape, weight_scale_shape), dtype=torch.float32
    )
    self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
    self.input_scale = None
    self.weight_scale = None
    self.weight_block_size = [block_size, block_size] # 新增:记录 block size
    if transpose_weights:
        self.weight = self.weight.t()

评论区精华

性能影响评估 性能

ElizaWszola 质疑 fused 和 packed 版本的性能差异,要求提供 benchmark 数据。

结论:Lucaskabela 提供了 micro-benchmark 表,显示 fused 在多数配置下更快(10-22%),但在大 batch(3072 tokens, 7168 hidden)下 packed 快 18%。同意跳过测试。 · 已解决

代码重复与可维护性 style

gemini-code-assist[bot] 指出 layernorm_utils.cuh 中的 scale 调整逻辑与 vectorized 版本重复,建议提取为 helper 函数并命名魔数。

结论:未直接处理,因为该修复代码已被移除,仅保留测试跳过。 · outdated

后续修复计划 设计

ProExpertProg 要求将原始修复代码保存为 draft PR,以便后续实现真正的融合。

结论:已创建 draft PR #40650,后续可在此基础上完善。 · 已解决

风险与影响

本 PR 仅跳过测试,未修改生产代码,风险极低。但需要注意:跳过测试意味着 B200 上 DeepGEMM UE8M0 路径的 rms+quant 融合未经验证,可能存在隐藏的正确性问题。此外,tests/utils.py 中添加 weight_block_size 属性可能影响其他依赖此类的测试,但该属性仅为新增字段,不会破坏现有逻辑。

用户:无直接影响,因为这是测试级别的变更。系统:B200 上相关测试不会再因预期失败而中断 CI。团队:需在后续 PR 中实现真正的融合修复(见 TODO),避免长期跳过测试导致回归遗漏。影响程度低,范围仅限于特定硬件(B200)的特定测试。

测试覆盖跳过 后续需修复

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论