Prhub

#37228 [ROCM][Bugfix] Use correct stride in cp_mha_gather_cache_kernel for hybrid model (#37228)

原始 PR 作者 jennyyyyzhen 合并时间 2026-03-27 01:33 文件变更 1 提交数 1 评论 6 代码增减 +25 / -6

执行摘要

修复 ROCM 后端在混合模型下 KV 缓存非连续内存访问错误,避免注意力输出 NaN。

混合模型(如Qwen3.5)的KV缓存布局为交错模式[K_0][V_0][K_1][V_1]...,而非原始假设的连续布局[K_all][V_all]update_hybrid_attention_mamba_layout使用as_strided_()重排KV块导致内存非连续,但内核使用硬编码指针算术,从而读取错误内存位置,产生垃圾值和NaN。PR body明确指出此问题,并引用测试命令#35925验证修复后无损坏响应。

该PR值得精读,尤其关注Triton内核中处理非连续内存的通用模式。设计决策亮点:采用传递stride而非仅第一维stride,以预防未来其他维度非连续导致的静默错误。建议团队审查其他类似内核是否存在相同假设,并优先修复shuffle路径问题。

讨论亮点

review中主要讨论点:

  1. gemini-code-assist[bot]指出修复不完整,仅适用于'NHD'缓存格式,而'SHUFFLE'格式路径仍使用硬编码步长,可能导致非连续缓存时错误内存访问。
  2. yuankaichen-amd询问测试模型和stride传递细节,作者回复测试了qwen3.5,并解释传递所有stride是更标准的Triton内核写法,可预防未来其他维度非连续导致的静默错误。结论:修复被接受,但shuffle路径问题被标记为TODO待后续处理。

实现拆解

修改仅涉及一个文件vllm/v1/attention/backends/rocm_aiter_fa.py。关键改动点:

  1. cp_mha_gather_cache_kernel函数签名中添加6个stride参数(k/v_cache_stride0/1/2)。
  2. 在内核指针计算中,用stride参数替换硬编码的num_heads * head_size * PAGE_SIZE等计算。
  3. 在调用函数cp_mha_gather_cache中,通过key_cache.stride()value_cache.stride()获取实际步长并传递给内核。
  4. do_kv_cache_update函数中添加TODO注释,指出shuffle路径同样存在此问题需后续修复。
文件 模块 状态 重要度
vllm/v1/attention/backends/rocm_aiter_fa.py attention/backends modified 8.0

关键符号

cp_mha_gather_cache_kernel cp_mha_gather_cache do_kv_cache_update

分析完成后,这里会展示 LLM 生成的相对完整源码片段和详细注释。

评论区精华

修复不完整:shuffle 路径未处理 正确性

gemini-code-assist[bot] 指出 stride 参数仅用于 'NHD' 格式,'SHUFFLE' 格式仍用硬编码计算,可能导致非连续缓存时错误内存访问。

结论:问题被确认,在 do_kv_cache_update 中添加 TODO 注释,待后续修复。 · pending

stride 传递设计决策 设计

yuankaichen-amd 询问为何传递所有 stride 而非仅第一维,作者解释这是更标准的 Triton 内核写法,可预防未来其他维度非连续导致的静默错误。

结论:采用传递所有 stride 的方案,增强代码健壮性。 · 已解决

风险与影响

技术风险:

  1. 回归风险:修改涉及核心注意力内核的指针计算,若stride传递或使用错误,可能导致内存访问越界或数据损坏。
  2. 兼容性风险:仅修复了'NHD'格式路径,'SHUFFLE'格式路径未修复,使用混合模型时若启用shuffle可能仍出错。
  3. 测试覆盖不足:PR body提到测试了#35925命令,但未提及是否有自动化测试覆盖此场景,可能依赖手动验证。风险具体位置:rocm_aiter_fa.py中的指针计算逻辑变更。

影响范围:

  1. 用户:使用ROCM后端运行混合模型(如Qwen3.5)的用户将修复注意力输出NaN问题,提升模型推理稳定性。
  2. 系统:仅影响ROCM后端的特定内核,对CUDA或其他后端无影响。
  3. 团队:揭示了Triton内核中硬编码内存假设的通用问题,可能促使其他类似内核的审查和修复。影响程度:中等,修复特定但关键的内存访问错误,避免混合模型推理失败。
核心路径变更 部分路径未修复 缺少自动化测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论