Prhub

#39799 [ROCm][CI] Fix TestSiluMulGroupFp8QuantModel after W8A8 block linear refactor

原始 PR 作者 AndreasKaratzas 合并时间 2026-04-25 10:21 文件变更 1 提交数 6 评论 0 代码增减 +19 / -10

执行摘要

修复 ROCm 上 SiluMul+FP8 融合测试因重构而中断

W8A8 Block Linear 重构(PR#33892)将 W8A8BlockFp8LinearOp 替换为 TestFP8Layer,但 TestSiluMulGroupFp8QuantModel 的 forward 调用和操作列表未同步更新,导致 ROCm 平台上的测试失败。此外,在非 fnuz ROCm(如 MI355)上,融合模式期望 Triton 分组量化操作,但小测试形状下 use_triton 标志为 False,需强制启用。

值得精读,特别是了解在重构后如何联动调整测试代码的实践。关注点:平台区分(fnuz vs 非 fnuz)、猴子补丁技巧、操作列表与编译传递的对应关系。

讨论亮点

该 PR 没有实质性的人工 review 讨论。自动机器人 gemini-code-assist[bot] 提供了摘要性评论,但无反馈意见。审核者 tjtanaa 直接批准(LGTM)。

实现拆解

  1. 移除冗余权重/缩放张量:在 TestSiluMulGroupFp8QuantModel.__init__ 中删除了 self.wself.wscale,因为 TestFP8Layer 在内部创建权重。
  2. 强制 Triton 量化路径:在非 fnuz ROCm 平台上,通过猴子补丁(monkey-patch)强制 kernel.quant_fp8use_triton=True,确保融合测试使用预期内核。
  3. 更新操作列表:在 ops_in_model_before 中根据平台动态返回 rocm_aiter_ops.get_group_quant_op()(fnuz)或 torch.ops.vllm.triton_per_token_group_quant_fp8.default(非 fnuz),保证编译传递能够识别正确操作。
  4. 调整容差阈值:为 TestSiluMulBlockQuantModel 在 ROCm 上设置更严格的容差(1e-3),同时 CUDA 上保持宽松(5e-2)以包容浮点计算差异。
  5. 导入调整:将 rocm_aiter_ops 的导入从函数体内提升到文件顶部,避免重复导入。
文件 模块 状态 重要度
tests/compile/passes/test_silu_mul_quant_fusion.py 编译测试 modified 6.16

关键源码片段

tests/compile/passes/test_silu_mul_quant_fusion.py test-coverage

唯一修改的文件,修复了 SiluMul+FP8 融合测试的所有三个问题:forward 调用、操作列表、Triton 量化路径。

# tests/compile/passes/test_silu_mul_quant_fusion.py
# 在 TestSiluMulGroupFp8QuantModel.__init__ 中,猴子补丁强制 Triton 路径
if not current_platform.is_fp8_fnuz():
    kernel = self.w8a8_block_fp8_linear.kernel
    orig_quant = kernel.quant_fp8
    # 将所有 quant_fp8 调用强制使用 use_triton=True
    kernel.quant_fp8 = lambda *a, use_triton=False, **kw: orig_quant(
        *a, use_triton=True, **kw
    )# ops_in_model_before 根据平台动态返回量化操作
# 对于 fnuz ROCm 使用 AITER 操作,否则使用 Triton 操作
# 这是因为融合模式需要精确匹配预期操作列表
def ops_in_model_before(self):
    return [
        SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
        rocm_aiter_ops.get_group_quant_op()
        if current_platform.is_fp8_fnuz()
        else torch.ops.vllm.triton_per_token_group_quant_fp8.default,
    ]# 在 test_fusion_silu_and_mul_quant 中区分模型类型调整容差
# ROCm 上 BlockQuant 模型使用更严格的 1e-3,CUDA 上保持 5e-2
elif isinstance(model, TestSiluMulBlockQuantModel):
    if current_platform.is_rocm():
        atol, rtol = 1e-3, 1e-3
    else:
        atol, rtol = 5e-2, 5e-2

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

低风险。变更仅限于测试文件 tests/compile/passes/test_silu_mul_quant_fusion.py,不涉及生产代码。主要风险在于:

  • 如果其他测试模式或实际生产路径也依赖于类似的 TestFP8Layer 调用模式,可能因未同步更新而失败。但本 PR 已针对特定模型修复。
  • 强制 use_triton=True 可能在小 shape 下引入 Triton 代码路径,但测试范围有限,不易产生副作用。

正面影响:恢复了 ROCm 平台上 SiluMulFP8 融合测试的正确性,确保编译器融合传递在 AMD GPU 上按预期工作。
影响范围:仅影响测试文件,用户不受直接影响。对开发团队,维护了 CI 的稳定性。

仅测试文件变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论