Prhub

#38615 [ROCm] Fix aiter persistent mode mla with q/o nhead<16 for kimi-k2.5 tp8

原始 PR 作者 wufann 合并时间 2026-04-03 18:54 文件变更 1 提交数 2 评论 4 代码增减 +4 / -3

执行摘要

修复 ROCm Aiter MLA 后端在注意力头数小于 16 时元数据分配与内核输入形状不匹配的问题。

PR body指出,AiterMLAImpl已包含头填充机制:当num_heads为4或8时,q张量会通过repeat_interleave填充到16个头再传递给内核。但persistent mode MLA实现为原始头数(如8)分配缓冲区,而内核接收的是16个头的q张量,导致预分配的元数据缓冲区与实际内核输入之间的形状不匹配。这影响了Kimi-K2.5模型在TP=8下的正确运行。

该PR值得精读,尤其关注头填充机制与元数据分配的一致性设计。对于ROCm平台开发者和多模态模型用户,可学习如何调试形状不匹配问题及利用max函数简化边界条件处理。

讨论亮点

review中gemini-code-assist[bot]建议简化逻辑,使用max(16, self.num_heads)来提高鲁棒性并与AiterMLAImpl实现保持一致。tjtanaa指出修复与_needs_head_repeat条件相关,并验证了gemini反馈的有效性。wufann接受了反馈并更新了代码。讨论焦点在于确保元数据分配与内核输入形状的一致性,无重大争议,结论明确采纳简化方案。

实现拆解

仅修改了vllm/v1/attention/backends/mla/rocm_aiter_mla.py文件。关键改动是将self._num_attention_heads的计算从vllm_config.model_config.get_num_attention_heads(vllm_config.parallel_config)改为max(16, self.num_heads)。这确保了元数据分配的头数至少为16,与AiterMLAImpl中头填充逻辑(_needs_head_repeat条件)保持一致,避免形状不匹配。

文件 模块 状态 重要度
vllm/v1/attention/backends/mla/rocm_aiter_mla.py attention/backends/mla modified 8.0

关键符号

__init__

分析完成后,这里会展示 LLM 生成的相对完整源码片段和详细注释。

评论区精华

元数据分配头数简化 正确性

gemini-code-assist[bot] 建议使用 max(16, self.num_heads) 简化逻辑,提高鲁棒性并与 AiterMLAImpl 实现保持一致。

结论:采纳建议,更新代码使用 max(16, self.num_heads)。 · 已解决

修复与 _needs_head_repeat 条件关联 设计

tjtanaa 指出修复与 AiterMLAImpl 中的 _needs_head_repeat 条件相关,验证了 gemini 反馈。

结论:确认修复正确性,关联代码逻辑。 · 已解决

风险与影响

风险较低。变更仅影响ROCm Aiter MLA后端在注意力头数小于16的场景,修复了形状不匹配的bug。潜在风险包括:

1) 对头数>=16的现有场景无影响,但需确保max(16, self.num_heads)不会意外改变其他配置;
2) 依赖self.num_heads的正确初始化,但该变量已在基类MLACommonMetadataBuilder中定义,风险可控。

影响范围有限但关键:修复了Kimi-K2.5等模型在ROCm平台TP=8下使用Aiter MLA后端时的运行错误,使GSM8K评估准确率从故障状态恢复至93.4%。仅影响使用该特定后端且头数小于16的用户,对大多数其他配置无影响。有助于提升ROCm平台多模态模型支持的稳定性。

形状不匹配修复 平台特定逻辑

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论