Prhub

#41825 [ROCm][Perf] Fix RMSNorm+Quant fusion for gfx950 (non-fnuz)

原始 PR 作者 frida-andersson 合并时间 2026-05-12 03:00 文件变更 3 提交数 9 评论 30 代码增减 +342 / -7

执行摘要

修复 gfx950 上 RMSNorm+FP8 融合,延迟降 3.3%

PR描述指出,gfx950(non-fnuz)上AITER RMSNorm+GroupedQuantFP8融合内核被静默跳过,导致性能未达预期。两个根因:一是MatcherQuantFP8在非fnuz平台错误选择了triton_per_token_group_quant_fp8而非rocm_aiter_group_fp8_quant;二是DSv3.2的FX图中RMSNorm输出被多个量化操作共享,违反1-to-1模式匹配约束。

值得精读该PR,尤其是matcher_utils.py的修正和DoubleAiterRMSFp8GroupQuantPattern的声明式模式实现。它展示了从手动FX图变换到声明式模式匹配的演进思路,以及view-tolerant变体处理实际生产图中常见噪声的经验。设计决策(重复rms_norm而非保留未融合的16位读取)也有借鉴意义。建议在撰写自定义编译pass时参考此模式。

讨论亮点

Rohan138确认matcher_utils.py的更改“good catch, LGTM”。
Rohan138询问duplicate quant的来源,frida-andersson通过VLLM_DEBUG_DUMP_PATH提供证据,确认DSv3.2特有。
tjtanaa要求将图变换门控到gfx950(“IMPORTANT NOTE: Do not import anything from vllm.platforms.rocm without guarding it with current_platform.is_rocm()”),并要求将logger.info降为logger.debug
ProExpertProg建议用声明式DoubleQuant模式替代手动图变换(“could we just add a new pattern -> replacement”),避免复杂且脆弱的手写变换。
ChuanLi1101实施模式并添加view-tolerant变体,ProExpertProg最终批准(“this is in fact much cleaner! Good work”)。
关于UUID一致性,tjtanaa担心移除clone_elimination.uuid()影响,ProExpertProg澄清“removes it from the pass key”,不影响pass实际执行。

实现拆解

  1. 修正算子选择matcher_utils.py):移除MatcherQuantFP8is_fp8_fnuz()条件判断,当match_rocm_aiter=True时始终使用rocm_aiter_ops.get_group_quant_op(),确保gfx950匹配正确算子。
  2. 新增DoubleQuant模式rocm_aiter_fusion.py):定义DoubleAiterRMSFp8GroupQuantPattern类,匹配一个rms_norm输出被两个不同group_fp8_quant消费的fan-out图,替换为两个独立融合的rms_norm_group_fp8_quant操作。
  3. 添加view-tolerant变体rocm_aiter_fusion.py):DoubleAiterRMSFp8GroupQuantViewPattern匹配中间有view/reshape的fan-out图(DSv3.2 MLA indexer特有形状),通过torch._inductor.fx_passes.post_grad.view_to_reshape将view转化为reshape,再匹配模式。
  4. 调整UUID缓存键pass_manager.py):恢复clone_elimination在非gfx950上的UUID参与缓存键,仅在gfx950上pop避免缓存未命中。
  5. 添加单元测试tests/compile/passes/test_double_aiter_rms_quant_fusion.py):用_NoViewDoubleQuantModel_ViewDoubleQuantModel两个模型分别测试无view和有view场景,参数化运行RocmAiterRMSNormQuantFusionPass并验证融合后的图中出现rocm_aiter_rmsnorm_fp8_group_quant节点。
文件 模块 状态 重要度
vllm/compilation/passes/fusion/rocm_aiter_fusion.py 编译融合 modified 8.72
vllm/compilation/passes/fusion/matcher_utils.py 编译匹配 modified 5.8
tests/compile/passes/test_double_aiter_rms_quant_fusion.py 编译测试 added 7.91

关键符号

DoubleAiterRMSFp8GroupQuantPattern DoubleAiterRMSFp8GroupQuantPattern.register DoubleAiterRMSFp8GroupQuantViewPattern DoubleAiterRMSFp8GroupQuantViewPattern.trace_with_view_to_reshape MatcherQuantFP8.__init__ test_double_aiter_rms_fp8_group_quant_fusion

关键源码片段

vllm/compilation/passes/fusion/rocm_aiter_fusion.py core-logic

核心变更文件,新增 DoubleAiterRMSFp8GroupQuantPattern 和 DoubleAiterRMSFp8GroupQuantViewPattern 两个模式类,分别处理无 view 和有 view 的 1-to-2 fan-out 融合场景。同时优化了 pass 的 uuid 生成。是整个 PR 的核心实现。

class DoubleAiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    Pattern matching ``rms_norm`` whose output feeds *two* distinct
    ``rocm_aiter_group_fp8_quant`` consumers, replacing it with two
    independent fused ``rms_norm_group_fp8_quant`` ops.    Repeating the rms_norm in the replacement is preferable to leaving
    the fused 16-bit rms output materialized for two unfused quant
    consumers, and matches what the previous manual graph surgery
    achieved by cloning the rms_norm node.
    """
​
    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()
​
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key, match_aiter_quant)
​
    def register(self, pm_pass: PatternMatcherPass) -> None:
        # 定义 pattern:一个 rms_norm 输出连接到两个相同的量化操作
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
            result1, scale1 = self.quant_matcher(result_rms)
            result2, scale2 = self.quant_matcher(result_rms)
            return result1, scale1, result2, scale2
​
        # 定义 replacement:用两个独立的 fused op 替代
        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            at1 = self.FUSED_OP(
                x=input, weight=weight, variance_epsilon=self.epsilon, group_size=128,
            )
            at2 = self.FUSED_OP(
                x=input, weight=weight, variance_epsilon=self.epsilon, group_size=128,
            )
            return at1[0], at1[1], at2[0], at2[1]
​
        pm.register_replacement(
            pattern, replacement,
            [self.empty(5, 16), self.empty(16)], # 示例输入
            pm.fwd_only, pm_pass,
        )
class DoubleAiterRMSFp8GroupQuantViewPattern(AiterRMSNormQuantPattern):
    """
    View-tolerant variant that matches the same fan-out but with a
    ``view``/``reshape`` between the ``rms_norm`` output and the two
    ``rocm_aiter_group_fp8_quant`` consumers.    This shape arises in DeepSeek-V3.2's MLA indexer q_c norm, where
    ``Fp8BlockScaledMMLinearKernel.apply_weights`` inserts a 2D-flatten
    view before each quant op.
    """
    ...
    @staticmethod
    def trace_with_view_to_reshape(graph: fx.Graph) -> None:
        # 将图中的 view 节点转换为 reshape,使 pattern 能匹配
        view_to_reshape(graph, skip_constructors=True)
        # 这里还可以处理连续 reshape 的折叠
tests/compile/passes/test_double_aiter_rms_quant_fusion.py test-coverage

新增单元测试,覆盖无 view 和有 view 两种 fan-out 形状,通过参数化模型验证 DoubleQuant 模式正确触发融合,提供回归保护。

class _NoViewDoubleQuantModel(torch.nn.Module):
    """``rms_norm -> 2x group_fp8_quant`` fan-out (Kimi-K2.5 / DSR1 shape)."""
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(HIDDEN_SIZE, dtype=torch.bfloat16))
​
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        x = torch.relu(x)
        rms = torch.ops.vllm_ir.rms_norm(x, self.weight, EPS)
        q1, s1 = torch.ops.vllm.rocm_aiter_group_fp8_quant.default(rms, GROUP_SIZE)
        q2, s2 = torch.ops.vllm.rocm_aiter_group_fp8_quant.default(rms, GROUP_SIZE)
        return q1, s1, q2, s2
​
​
class _ViewDoubleQuantModel(torch.nn.Module):
    """``rms_norm -> view -> 2x group_fp8_quant`` fan-out (DSv3.2 shape)."""
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(HIDDEN_SIZE, dtype=torch.bfloat16))
​
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        x = torch.relu(x)
        rms = torch.ops.vllm_ir.rms_norm(x, self.weight, EPS)
        view = rms.view(-1, rms.shape[-1])
        q1, s1 = torch.ops.vllm.rocm_aiter_group_fp8_quant.default(view, GROUP_SIZE)
        q2, s2 = torch.ops.vllm.rocm_aiter_group_fp8_quant.default(view, GROUP_SIZE)
        return q1, s1, q2, s2
​
​
@pytest.mark.parametrize("model_cls", [_NoViewDoubleQuantModel, _ViewDoubleQuantModel], ids=["no_view", "with_view"])
@pytest.mark.skipif(not is_aiter_found_and_supported(), reason="Only test on ROCm with AITER installed and supported")
def test_double_aiter_rms_fp8_group_quant_fusion(model_cls: type[torch.nn.Module], monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Both fan-out shapes must fuse into ``rocm_aiter_rmsnorm_fp8_group_quant``.
    Failure on the ``with_view`` parametrization is a regression on the
    DSv3.2 q_c norm path that this PR's view-tolerant pattern is intended to cover.
    """
    torch._dynamo.reset()
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=torch.bfloat16),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
        ),
    )
    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        from vllm.compilation.passes.fusion.rocm_aiter_fusion import RocmAiterRMSNormQuantFusionPass
        torch.set_default_device("cuda")
        torch.set_default_dtype(torch.bfloat16)
        torch.manual_seed(0)
        m.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()
        fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)
        passes = [NoOpEliminationPass(vllm_config), fusion_pass, PostCleanupPass(vllm_config)]
        backend = TestBackend(passes=passes)
        model = model_cls().eval()
        x = torch.randn(10, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda")
        result = torch.compile(model, backend=backend, fullgraph=True)(x)
        # 验证 fused 图形中包含正确的融合节点
        # ( 测试后端会检查至少一个 pattern 被替换 )

评论区精华

gfx950 门控 (gating) 设计

tjtanaa 要求手动图变换 `_dedup_and_duplicate_for_fusion` 必须仅限于 gfx950,避免影响其他 ROCm 平台 (gfx942/MI300 等 )。同时要求导入 `on_gfx950` 时需先检查 `current_platform.is_rocm()`。

结论:ChuanLi1101 在 commit 90474f7 中实施门控,仅在 `on_gfx950()` 下执行图变换,且 `on_gfx950` 的导入被包裹在 `is_rocm()` 检查内。后续模式版本保留相同门控。 · 已解决

DoubleQuant pattern 替代手动图变换 设计

ProExpertProg 建议用声明式 DoubleQuant 模式替代复杂且脆弱的手动 `_dedup_and_duplicate_for_fusion` 变换,认为重复 rms_norm 比保留未融合的 16 位读取更好。tjtanaa 同意并提议先合并再跟随。

结论:ChuanLi1101 在 commit 51502209 中移除手动变换,改用 `DoubleAiterRMSFp8GroupQuantPattern` 声明式模式。ProExpertProg 批准。 · 已解决

UUID 缓存键一致性 正确性

tjtanaa 询问为何移除 `clone_elimination.uuid()` 以及是否会影响缓存正确性。ProExpertProg 解释只影响缓存键,不影响 pass 实际执行。ChuanLi1101 在 commit 317a9eb 中恢复 UUID 追加,仅在 gfx950 上 pop。

结论:在非 gfx950 平台保持原缓存键,gfx950 上移除 clone_elimination 的 UUID 以避免缓存未命中。 · 已解决

日志级别调整 style

tjtanaa 要求将 `Pre-fusion: deduped...` 和 `Replaced X patterns` 的日志从 `logger.info` 降为 `logger.debug`,避免日志过多。

结论:ChuanLi1101 在 commit 317a9eb 中改为 `logger.debug`。 · 已解决

view-tolerant 模式验证 测试

akii96 确认 view-tolerant `DoubleAiterRMSFp8GroupQuantViewPattern` 在 DSv3.2 q_c norm fan-out 上成功触发,并提供了 BEFORE/AFTER FX 图证据。

结论:模式正确匹配,性能提升确认。 · 已解决

风险与影响

  1. matcher_utils.py全局影响:移除is_fp8_fnuz()分支仅影响match_rocm_aiter=True路径,但仍可能在非gfx950的ROCm平台上改变量化算子选择,需确保所有match_rocm_aiter用例(gfx942等)正确调用get_group_quant_op()
  2. view-tolerant模式副作用:引入view_to_reshape可能改变图结构,但该函数是PyTorch内置转换,已在实际模型上验证。
  3. UUID缓存键调整:在gfx950上移除clone_elimination.uuid()可能导致缓存未命中,但此影响已通过仅在gfx950上pop并确保non-gfx950完整键来缓解。
  4. 测试覆盖有限:仅覆盖两种fan-out形状,未涵盖其他可能出现view的场景(如AiterFusedAddRMSFp8GroupQuantPattern的fan-out),但实际模型已验证无回归。

影响范围:仅ROCm平台,且仅AITER可用且使用RocmAiterRMSNormQuantFusionPass的场景(即启用fuse_norm_quant=Truematch_rocm_aiter=True的编译模式)。具体为gfx950(MI355X)上运行DeepSeek-V3.2模型时性能提升约3.3%(TP4, bf16, HIP graphs)。用户影响:DSv3.2用户可直接受益;其他模型(Kimi-K2.5等)融合模式无影响(Replaced 0 patterns)。团队影响:无直接开发负担,但后续应关注其他ROCm平台是否出现类似算子选择问题。

核心编译逻辑变更 gfx950 特定门控 模式匹配依赖图结构 可能影响其他 ROCm 平台(已验证无回归)

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论