Prhub

#39242 [ROCm] Add MLA dual RMS norm fusion (Q, KV) pass for DeepSeek/Kimi-K2

原始 PR 作者 rbrugaro-amd 合并时间 2026-04-20 22:56 文件变更 8 提交数 12 评论 19 代码增减 +361 / -1

执行摘要

为 ROCm 平台添加 MLA 双 RMSNorm 融合优化,提升 DeepSeek-V3/Kimi-K2 模型性能。

根据PR body描述,未融合的MLA层需要运行两个独立的RMSNorm调用(分别处理q和kv压缩隐变量),导致每个层有2次内核启动。融合为一个内核可减少启动开销,提升模型推理性能,特别是在DeepSeek-V3和Kimi-K2这类具有61层MLA结构的模型上效果显著。关联Issue #2442(ROCm/aiter)提供了底层的fused_qk_rmsnorm HIP内核支持。

建议工程团队精读此PR,重点关注MLADualRMSNormPattern的模式设计如何动态推导split尺寸,以及VllmFusionPatternMatcherPass的使用范例。对于涉及图优化或硬件特定加速的开发者,此PR展示了如何通过torch.inductor模式匹配安全地融合复杂操作子图,具有较高参考价值。

讨论亮点

核心讨论点

  • 图拓扑顺序风险:gemini-code-assist[bot]指出初始实现中手动遍历节点可能导致输入节点未正确提升,引发编译失败。开发者rbrugaro-amd随后重构为模式匹配方案以规避此问题。
  • 代码结构与最佳实践:Rohan138建议使用torch.inductor的PatternMatcher而非手动迭代,并将pass整合到现有rocm_aiter_fusion.py文件中;ProExpertProg进一步推荐使用VllmFusionPatternMatcherPass基类使代码更简洁,这些建议均被采纳。
  • 配置条件细化:Rohan138对enable_mla_dual_rms_norm_fusion函数仅检查AITer可用性提出疑问,但ProExpertProg认为当前条件足够,未引入额外限制。
  • 文档与命名规范:ProExpertProg请求在文档中添加关于Inductor默认融合的说明;Rohan138指出配置日志中的“AITer”拼写应统一为“AITER”。

实现拆解

  1. 自定义操作注册:在vllm/_aiter_ops.py中新增_fused_mla_dual_rms_norm_impl_fused_mla_dual_rms_norm_fake函数,并通过direct_register_custom_op注册为fused_mla_dual_rms_norm操作,封装AITer的fused_qk_rmsnorm内核。
  2. 模式匹配pass实现:在vllm/compilation/passes/fusion/rocm_aiter_fusion.py中新增MLADualRMSNormPattern类(继承VllmPatternReplacement),定义模式识别和替换逻辑,将连接的子图(包含split操作和两个rms_norm调用)重写为单个融合操作;MLADualRMSNormFusionPass类(继承VllmFusionPatternMatcherPass)负责在FX图中应用该模式。
  3. 配置系统集成:在vllm/config/vllm.py中添加enable_mla_dual_rms_norm_fusion函数,根据AITer可用性控制融合开关;在vllm/config/compilation.pyPassConfig中新增fuse_mla_dual_rms_norm布尔字段,并添加ROCm平台检查;在vllm/compilation/passes/pass_manager.py中注册该pass到pass流水线。
  4. 测试配套:新增tests/compile/passes/test_fuse_mla_dual_rms_norm.py单元测试,包含MLADualRMSNormTestModel模型和test_fuse_mla_dual_rms_norm测试函数,验证模式匹配、操作替换和数值正确性。
  5. 文档更新:更新docs/design/fusions.mddocs/design/optimization_levels.md,记录该融合pass的使用说明和配置项。
文件 模块 状态 重要度
vllm/compilation/passes/fusion/rocm_aiter_fusion.py 编译融合 modified 8.36
tests/compile/passes/test_fuse_mla_dual_rms_norm.py 测试覆盖 added 7.3
vllm/_aiter_ops.py 操作注册 modified 7.52
vllm/config/vllm.py 配置系统 modified 5.94
vllm/config/compilation.py 配置系统 modified 5.48

关键符号

MLADualRMSNormPattern MLADualRMSNormFusionPass _fused_mla_dual_rms_norm_impl enable_mla_dual_rms_norm_fusion test_fuse_mla_dual_rms_norm

关键源码片段

vllm/compilation/passes/fusion/rocm_aiter_fusion.py core-logic

核心实现文件,包含 MLA 双 RMSNorm 融合的模式匹配和 pass 逻辑,定义了如何识别和重写 FX 图。

class MLADualRMSNormPattern(
    VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
):
    """
    融合MLA注意力中配对的q_a_layernorm和kv_a_layernorm到AITER的fused_qk_rmsnorm HIP内核。
    目标FX图模式(未融合,vllm_ir阶段):
        gemm -> split_with_sizes([q_dim, kv_dim])
            +-- q_c -> vllm_ir.rms_norm(q_c, q_w, eps)
            +-- kv_lora -> split_with_sizes([kv_c_dim, k_pe_dim])
                    +-- kv_c -> vllm_ir.rms_norm(kv_c, kv_w, eps)
                    +-- k_pe
    """
​
    def __init__(self, epsilon: float) -> None:
        self._epsilon = epsilon # 设置 epsilon 参数,用于模式匹配中的 RMSNorm 计算
​
    def get_inputs(self) -> list[torch.Tensor]:
        # 提供虚拟输入用于模式验证,尺寸任意但保持维度一致性
        q_dim, kv_c_dim, k_pe_dim = 8, 4, 2
        return [
            self.empty_bf16(5, q_dim + kv_c_dim + k_pe_dim), # projected 输入
            self.empty_bf16(q_dim), # q_weight
            self.empty_bf16(kv_c_dim), # kv_weight
        ]
​
    @property
    def pattern(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        eps = self._epsilon
        def _pattern(projected: torch.Tensor, q_weight: torch.Tensor, kv_weight: torch.Tensor):
            q_dim = q_weight.shape[0] # 动态获取 q 维度
            kv_dim = projected.shape[-1] - q_dim # 计算 kv 总维度
            kv_c_dim = kv_weight.shape[0] # 动态获取 kv_c 维度
            k_pe_dim = kv_dim - kv_c_dim # 计算 k_pe 维度
            q_c, kv_lora = projected.split([q_dim, kv_dim], dim=-1) # 第一次 split
            kv_c, k_pe = kv_lora.split([kv_c_dim, k_pe_dim], dim=-1) # 第二次 split
            q_normed = vllm.ir.ops.rms_norm(q_c, q_weight, eps) # q 的 RMSNorm
            kv_normed = vllm.ir.ops.rms_norm(kv_c, kv_weight, eps) # kv 的 RMSNorm
            return q_normed, kv_normed, k_pe # 返回三个输出
        return _pattern
​
    @property
    def replacement(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        eps = self._epsilon
        def _replacement(projected: torch.Tensor, q_weight: torch.Tensor, kv_weight: torch.Tensor):
            q_dim = q_weight.shape[0]
            kv_dim = projected.shape[-1] - q_dim
            kv_c_dim = kv_weight.shape[0]
            k_pe_dim = kv_dim - kv_c_dim
            q_c, kv_lora = projected.split([q_dim, kv_dim], dim=-1)
            kv_c, k_pe = kv_lora.split([kv_c_dim, k_pe_dim], dim=-1)
            # 使用融合操作替换两个独立的 RMSNorm 调用
            q_normed, kv_normed = torch.ops.vllm.fused_mla_dual_rms_norm(
                q_c, q_weight, kv_c, kv_weight, eps, eps
            )
            return q_normed, kv_normed, k_pe
        return _replacement
tests/compile/passes/test_fuse_mla_dual_rms_norm.py test-coverage

单元测试文件,验证融合 pass 的正确性,包括模式匹配、操作替换和数值精度。

class MLADualRMSNormTestModel(torch.nn.Module):
    """
    最小化模型,复现MLA双RMSNorm模式:
        linear -> split([q_dim, kv_dim])
            +-- q_c -> rms_norm(q_w, eps) -> linear
            +-- kv_lora -> split([kv_c_dim, k_pe_dim])
                    +-- kv_c -> rms_norm(kv_w, eps)
                    +-- k_pe
    """
    def __init__(self, hidden_size: int, q_dim: int = 1536, kv_c_dim: int = 512, k_pe_dim: int = 64, eps: float = 1e-6):
        super().__init__()
        self.q_dim = q_dim
        self.kv_dim = kv_c_dim + k_pe_dim
        self.kv_c_dim = kv_c_dim
        self.k_pe_dim = k_pe_dim
        self.proj = torch.nn.Linear(hidden_size, q_dim + self.kv_dim, bias=False) # 投影层
        self.q_norm = RMSNorm(q_dim, eps=eps) # q 的 RMSNorm 层
        self.kv_norm = RMSNorm(kv_c_dim, eps=eps) # kv 的 RMSNorm 层
        self.q_b_proj = torch.nn.Linear(q_dim, hidden_size, bias=False) # 后续线性层
​
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = torch.relu(x) # 避免输入直接作为模式节点
        projected = self.proj(x) # 投影操作
        q_c, kv_lora = projected.split([self.q_dim, self.kv_dim], dim=-1) # 第一次 split
        kv_c, k_pe = kv_lora.split([self.kv_c_dim, self.k_pe_dim], dim=-1) # 第二次 split
        q_normed = self.q_norm(q_c) # 原始 q RMSNorm
        kv_normed = self.kv_norm(kv_c) # 原始 kv RMSNorm
        q_out = self.q_b_proj(q_normed) # 后续处理
        return q_out, kv_normed, k_pe # 返回三个输出,用于数值比较
​
    def ops_in_model_before(self):
        return [torch.ops.vllm_ir.rms_norm.default] # 融合前期望的操作
​
    def ops_in_model_after(self):
        return [torch.ops.vllm.fused_mla_dual_rms_norm.default] # 融合后期望的操作

评论区精华

图拓扑顺序与模式匹配重构 设计

gemini-code-assist[bot] 指出初始手动遍历节点实现可能存在输入节点顺序问题,导致编译失败;Rohan138 建议使用 torch.inductor 模式匹配替代,并整合到现有文件。

结论:开发者重构代码,采用 PatternMatcher 和 VllmFusionPatternMatcherPass,解决了拓扑风险并遵循最佳实践。 · 已解决

配置条件与文档更新 设计

Rohan138 对 enable_mla_dual_rms_norm_fusion 仅检查 AITer 可用性提出疑问,ProExpertProg 认为当前足够;同时 reviewers 要求更新相关设计文档。

结论:配置逻辑保持原样,文档已更新以记录融合 pass 的详细信息。 · 已解决

代码风格与命名规范 style

ProExpertProg 提出代码格式建议(如简化返回语句),Rohan138 指出配置日志中“AITer”拼写应统一为“AITER”。

结论:开发者采纳格式建议并修正拼写,确保代码一致性。 · 已解决

风险与影响

  1. 平台依赖性风险:融合仅适用于ROCm平台且依赖外部AITer库(PR #2442),若AITer不可用或版本不兼容,融合将自动禁用,但可能引发用户困惑。
  2. 模式匹配健壮性风险MLADualRMSNormPattern依赖于特定的FX图结构(如split尺寸和rms_norm调用顺序),若模型图结构变化(例如不同MLA变体),可能导致匹配失败或误匹配。
  3. 数值精度风险:融合内核fused_qk_rmsnorm与原始两个RMSNorm操作的数值等价性依赖AITer实现,虽经单元测试验证,但在边缘情况(如极端epsilon值)下仍需监控。
  4. 编译时性能风险:模式匹配增加了图遍历开销,可能轻微影响编译时间,但鉴于融合仅在优化等级≥O1时触发,影响可控。
  1. 用户影响:对于使用DeepSeek-V3或Kimi-K2模型的ROCm用户,在启用优化(默认O1及以上)后可获得约1.02倍的吞吐量提升(基于PR中性能数据),无需修改模型代码或配置。
  2. 系统影响:减少内核启动次数,降低GPU调度开销,有助于提高硬件利用率;但仅影响MLA注意力层,不改变其他模型组件或API接口。
  3. 团队影响:引入了新的编译pass和配置项,增加了ROCm专用优化模块的复杂性,需团队在后续维护中熟悉模式匹配框架和AITer集成。
依赖外部 AITer ROCm 平台特定 模式匹配复杂性 数值精度验证

关联 Issue

#2442 add fused_qknorm hip kernel

完整报告

参与讨论