Prhub

#42880 [ROCm] Guard AITER GDN decode fast path by layout

原始 PR 作者 tuukkjs 合并时间 2026-05-19 02:56 文件变更 1 提交数 4 评论 5 代码增减 +8 / -4

执行摘要

修复 Qwen3.5 在 ROCm 上 GDN 精度回归

PR #40711 为 GatedDeltaNet 注意力添加了 AITER Triton 融合解码路径,显著减少了内核启动开销。但该融合内核假设输入使用 Qwen3-Next 的交叠 GQA 布局。Qwen3.5 使用非交叠 [q, k, v, z][b, a] 投影布局,直接传入融合 reshape 内核会读取错误列,导致 GSM8K 精度从 ~0.92 骤降至 ~0.0。

值得精读,尤其是对 ROCm 平台上 Triton 内核布局假设敏感的推理引擎开发者。它展示了一个典型的风险:当优化内核假设特定数据布局时,不匹配会导致静默精度崩溃。建议在集成测试中增加对多种布局的端到端验证。

讨论亮点

Reviewer tpopp 确认 PR 本身无问题,但询问如何防止未来类似问题(尤其是 Triton 内核对于布局的假设无法通过类型系统强制)。Reviewer tjtanaa 要求使用完整 GSM8K 数据集和大 batch size(128)验证,以消除噪声。作者 tuukkjs 提供了完整运行结果,证实修复有效。最终 reviewer tjtanaa 批准。

实现拆解

  1. 定位入口:在 gdn_linear_attn.py_forward_core_rocm 方法中,找到 decode 快速路径的调用点(原先仅检查 spec_sequence_masks is Nonenum_prefills == 0num_decodes > 0)。
  2. 增加布局守卫:在条件判断前添加 self.gqa_interleaved_layout 检查,同时更新方法文档字符串以明确说明该路径仅对 interleaved-GQA 布局生效。
  3. 回落逻辑:当 self.gqa_interleaved_layoutFalse 时,不再调用 _forward_core_decode_fast,转而执行原有的 unpack + _forward_core 路径,保证非交叠布局的分割和重排正确。
  4. 验证:通过在完整 GSM8K 数据集上对比前后精度(batch_size=128)确认修复有效,精度从几乎 0% 恢复至 ~86.6%。
文件 模块 状态 重要度
vllm/model_executor/layers/mamba/gdn_linear_attn.py 注意力层 modified 6.42

关键符号

_forward_core_rocm

关键源码片段

vllm/model_executor/layers/mamba/gdn_linear_attn.py data-contract

核心修复文件,修改了 `_forward_core_rocm` 方法的 decode 快速路径条件。

    def _forward_core_rocm(
        self,
        qkvz: torch.Tensor,
        ba: torch.Tensor,
        z_out: torch.Tensor,
        core_attn_out: torch.Tensor,
    ):
        # ... docstring 已更新 ...
​
        forward_context = get_forward_context()
        attn_metadata_raw = forward_context.attn_metadata
​
        if attn_metadata_raw is None:
            # warmup 路径不变
            ...
            return
​
        assert isinstance(attn_metadata_raw, dict)
        attn_metadata = attn_metadata_raw[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
​
        # 关键修复:仅在 interleaved 布局下使用 AITER 快速路径
        # Qwen3-Next (gqa_interleaved_layout=True) 继续走融合内核
        # Qwen3.5 (gqa_interleaved_layout=False) 回落通用路径
        if (
            self.gqa_interleaved_layout # <-- 新增守卫
            and attn_metadata.spec_sequence_masks is None
            and attn_metadata.num_prefills == 0
            and attn_metadata.num_decodes > 0
        ):
            return self._forward_core_decode_fast(
                qkvz=qkvz,
                ba=ba,
                z_out=z_out,
                core_attn_out=core_attn_out,
                attn_metadata=attn_metadata,
            )
​
        # 非 interleaved 或混合批次走通用路径
        core_attn_out.zero_()
        z_out.zero_()
        num_tokens_all = qkvz.shape[0]
        mixed_qkv, z, b, a = self.prepare_gdn_attention_core_inputs(
            qkvz, ba, num_tokens_all
        )
        z_out[:] = z
        self._forward_core(
            mixed_qkv=mixed_qkv,
            b=b,
            a=a,
            core_attn_out=core_attn_out,
        )

评论区精华

如何防止未来类似布局假设问题 设计

Reviewer tpopp 询问如何通过类型或自动化手段防止 Triton 内核布局假设导致的静默错误。

结论:暂无具体方案,纯 PyTorch 能感知布局,但包装的 Triton 内核缺乏布局信息传播机制。 · unresolved

需要更大规模验证以消除噪声 测试

Reviewer tjtanaa 要求使用完整 GSM8K 和大 batch_size (128) 验证。

结论:作者提供了完整验证结果,精度从 ~0% 恢复至 ~86.6%。 · 已解决

风险与影响

风险较低:

  • 仅修改了一个条件判断,新增的 self.gqa_interleaved_layout 守卫是通过现有实例属性访问,无外部依赖。
  • 对于 Qwen3-Next(已设置 gqa_interleaved_layout=True),行为完全不变,不会影响其性能。
  • 对于其他可能使用非交叠布局的未来模型,该守卫自动将其导向通用路径,安全性高。
  • 缺少对应的单元测试,修复依赖手动验证;建议在后续 PR 中添加布局敏感的融合内核测试。

直接影响 ROCm 上 Qwen3.5 模型的 GDN 注意力精度,从几乎不可用恢复至正常水平。Qwen3-Next 及其它已设置 gqa_interleaved_layout=True 的模型无影响。修改仅 1 个文件,对系统其他模块无副作用。

精度回归修复 缺少测试覆盖 核心路径变更

关联 Issue

#40711 [Aiter][ROCm] gdn_linear_attn kernel fusion

完整报告

参与讨论