Prhub

#40763 [Bug] Fix GLM-5.1 running error on ROCm platform

原始 PR 作者 qli88 合并时间 2026-04-25 03:54 文件变更 3 提交数 1 评论 3 代码增减 +68 / -27

执行摘要

修复 GLM-5.1 在 ROCm 上的 MLA 头部填充问题

GLM-5.1 模型在 ROCm 平台上运行失败,原因是 AITER MLA 实现要求 num_heads>=16,而 GLM-5.1 在使用特定 tensor_parallel_size 时头部数量可能小于16。同时 AITER 存在一个除零错误,需要规避。

该 PR 值得仔细阅读,特别是 AiterMLAHelper 类的设计——将特定后端的特殊需求集中管理,避免散落在各个 forward 方法中。建议未来在 AITER 上游修复后及时移除 workaround(参见代码中的 TODO)。

讨论亮点

gemini-code-assist[bot] 提出了三点建议:

  • 修正 docstring 中的逻辑错误:16 // num_heads == 0 应为 16 % num_heads == 0
  • 更新错误消息以反映支持除16的倍数外的约数(如4,8)
  • 在 is_valid_num_heads 中处理 num_heads<=0 的情况,避免 ZeroDivisionError
    这些建议均已被采纳。

实现拆解

  1. 创建 AiterMLAHelper 类 (vllm/v1/attention/backends/mla/rocm_aiter_mla.py): 封装了头部有效性检查、实际头部数量获取、Q 填充和 O 解填充等静态方法。
  2. 常规 MLA 后端适配: 在 AiterMLAImpl 中,将原有的内联头部重复逻辑替换为 AiterMLAHelper 的方法,包括检查头部有效性、填充 Q、创建输出张量、最终解填充 O。
  3. 稀疏 MLA 后端适配 (vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py): 导入 AiterMLAHelper,在构造函数中调用检查方法,在前向传播中使用填充后的 Q 和实际头部数量,并解填充输出。
  4. 修复 AITER 除零错误 (vllm/v1/attention/ops/rocm_aiter_mla_sparse.py): 将 ChunkQ 参数设置为 heads 以规避除零问题,并添加 TODO 注释标记待后续移除。
文件 模块 状态 重要度
vllm/v1/attention/backends/mla/rocm_aiter_mla.py 注意力后端 modified 8.41
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py 注意力后端 modified 5.99
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py 注意力操作 modified 3.08

关键符号

AiterMLAHelper.check_num_heads_validity AiterMLAHelper.is_valid_num_heads AiterMLAHelper.get_actual_mla_num_heads AiterMLAHelper.get_mla_padded_q AiterMLAHelper.get_mla_unpadded_o AiterMLAImpl.forward_mqa ROCMAiterMLASparseImpl._forward_bf16_kv ROCMAiterMLASparseImpl.forward_mqa

关键源码片段

vllm/v1/attention/backends/mla/rocm_aiter_mla.py core-logic

核心变更文件,新增 AiterMLAHelper 类封装头部填充逻辑,并重构 AiterMLAImpl 使用该辅助类。

# class AiterMLAHelper: 封装 AITER MLA 头部填充逻辑
# AITER 要求 num_heads >= 16,若小于 16 则通过 repeat_interleave 填充
# 待下游计算完毕后,再通过切片还原class AiterMLAHelper:
    _AITER_MIN_MLA_HEADS: Final = 16
​
    @staticmethod
    def check_num_heads_validity(num_heads: int):
        # 校验头部数量是否有效(必须是 16 的倍数或约数)
        assert AiterMLAHelper.is_valid_num_heads(num_heads), (
            f"Aiter MLA requires that num_heads be multiples or divisors of 16, "
            f"but provided {num_heads} number of heads.\n"
            f"Try adjusting tensor_parallel_size value."
        )
​
    @staticmethod
    def is_valid_num_heads(num_heads: int) -> bool:
        # 当 num_heads >= 16 时必须是 16 的倍数;否则 16 必须能被 num_heads 整除
        return (
            num_heads % AiterMLAHelper._AITER_MIN_MLA_HEADS == 0
            if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
            else AiterMLAHelper._AITER_MIN_MLA_HEADS % num_heads == 0
        )
​
    @staticmethod
    def get_actual_mla_num_heads(num_heads: int) -> int:
        # 实际用于内核的头部数:至少为 16
        return max(num_heads, AiterMLAHelper._AITER_MIN_MLA_HEADS)
​
    @staticmethod
    def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor:
        # 若 num_heads < 16,沿头部维度 repeat_interleave 到 16
        return (
            q
            if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
            else q.repeat_interleave(
                AiterMLAHelper._AITER_MIN_MLA_HEADS // num_heads, dim=1
            )
        )
​
    @staticmethod
    def get_mla_unpadded_o(num_heads: int, o: torch.Tensor) -> torch.Tensor:
        # 将填充后的输出切回原始头部数(每隔 factor 取一个)
        return (
            o
            if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
            else o[:, :: AiterMLAHelper._AITER_MIN_MLA_HEADS // num_heads, :]
        )# 在 AiterMLAImpl 的 forward_mqa 中的使用示例
mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
o = torch.empty(B, mla_num_heads, self.kv_lora_rank, ...)
# ... 调用 AITER 内核计算 ...
return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, o)
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py dependency-wiring

稀疏 MLA 后端适配,导入并使用 AiterMLAHelper 进行头部检查、Q 填充和 O 解填充。

# 稀疏 MLA 后端的头部适配关键变更
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
    AiterMLAHelper,
)class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]):
    def __init__(self, ..., **mla_args):
        AiterMLAHelper.check_num_heads_validity(num_heads) # 新增头部校验
        # ...
​
    def _forward_bf16_kv(self, q, kv_c_and_k_pe_cache, topk_indices, attn_metadata):
        mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
        output = torch.empty([num_tokens, mla_num_heads, self.kv_lora_rank], ...) # 使用实际头部数
        # ... 调用 AITER 内核 ...
        return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, output) # 解填充
​
    def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
        mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q) # 对 Q 填充
        attn_out = self._forward_bf16_kv(mla_padded_q, ...)
        return attn_out, None
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py infrastructure

修复 AITER 除零错误,将 ChunkQ 参数设置为 heads 以规避问题。

# 修复 AITER 除零错误的 workaround
# 当 heads 较小时,AITER 内部计算 ChunkQ 可能为 0 导致除零错误
# 此处直接传递 heads 作为 ChunkQ,待 AITER PR 合并后移除deepgemm_fp8_paged_mqa_logits_stage1(
    q_fp8, kv_cache_fp8,
    ...
    ChunkQ=heads, # 原为 ChunkQ=auto 或缺失,现显式传递
)

评论区精华

docstring 逻辑错误 正确性

gemini-code-assist[bot] 指出 docstring 中 `16 // num_heads == 0` 应为 `16 % num_heads == 0`。

结论:作者接受建议,已修正。 · 已解决

错误消息未反映实际支持范围 正确性

gemini-code-assist[bot] 指出错误消息仅提及 multiples of 16,但实际支持 divisors(如 4,8),应更新。

结论:作者接受建议,已更新。 · 已解决

is_valid_num_heads 需要处理非正输入 正确性

gemini-code-assist[bot] 建议添加对 num_heads<=0 的检查,避免 ZeroDivisionError。

结论:作者接受建议,已添加。 · 已解决

风险与影响

  1. 性能影响: 头部填充/解填充会引入额外内存复制和计算开销,但仅影响头部数小于16的场景,影响可控。
  2. 正确性风险: padding 逻辑可能在非标准 MLA 实现中导致非对齐错误,但通过 lm_eval 验证了 GLM-5.1 和 DeepSeek-R1 的精度在可接受范围内。
  3. 兼容性: 该修改仅限于 ROCm 平台的 AITER 后端,不影响其他 GPU 后端。
  1. 用户: 使得 GLM-5.1 在 ROCm 平台(TP=8 或 TP=4)上可以正常运行。
  2. 系统: 增加了 padding 和 unpadding 步骤,但仅在小头部数场景生效,整体开销低。
  3. 团队: 引入 AiterMLAHelper 类统一了头部处理逻辑,便于后续维护和扩展。
依赖上游 AITER 修复 padding 引入性能开销 仅影响 ROCm 平台

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论