Prhub

#43817 [ROCm] Add attention sink support to AITer flash attention backend

原始 PR 作者 sphinx07 合并时间 2026-05-30 18:13 文件变更 3 提交数 2 评论 13 代码增减 +25 / -5

执行摘要

ROCm AITer Flash Attention 后端支持 attention sink

为 ROCm 平台上的 AITer 后端启用 attention sink 机制,以支持需要保留早期 token 信息的模型(如 GPT-OSS-120B)。作者在评论中提到该变更已在 MI350x 上验证。

建议精读 rocm_aiter_fa.py 中 decode 路径的内核切换逻辑,这是一个典型的「功能开关驱动内核选择」模式。建议作者补充对 AITer 版本的兼容性处理,并添加至少一个单元测试验证 sinks 路径不被意外绕过。

讨论亮点

tjtanaa 询问了模型兼容性和 AITer 库版本问题(v0.1.13/0.1.14),作者回应已使用 GPT-OSS-120B 在 MI350x 上测试,但尚未验证公共 PyPI 版本是否包含 sink_ptr 参数,建议未来可用 try/except 做兼容处理。该讨论未实际修改代码,最终合并时未引入兼容性检查。

实现拆解

  1. 声明后端能力:在 AiterFlashAttentionBackend 类中新增 supports_sink 类方法,返回 True,表明该后端支持 attention sink。
  2. 透传 sink 参数:在 vllm/_aiter_ops.pyflash_attn_varlen_func 包装函数签名中添加 sink_ptr 参数(默认 None),并传递给底层 aiter.flash_attn_varlen_func
  3. 存储 sinks 实例:在 AiterFlashAttentionImpl.__init__ 中新增 sinks 参数,保存到 self.sinks
  4. 路径贯通:在 prefill、extend、decode 各 forward 路径中,将 self.sinks 作为 sink_ptrsinks 参数传递给对应内核函数(flash_attn_varlen_funcunified_attention)。
  5. 自动切换内核:在 decode 路径中,当 sinks 不为 None 时,放弃使用 pa_fwd_asmpaged_attention_v1(它们不支持 sinks),改为使用 unified_attention。同时更新了相关断言信息。
  6. 文档同步:在 docs/design/attention_backends.md 中,将 ROCM_AITER_FA 行的 Sinks 列从 ❌ 改为 ✅。
文件 模块 状态 重要度
vllm/v1/attention/backends/rocm_aiter_fa.py 注意力 modified 6.91
vllm/_aiter_ops.py 操作包装 modified 4.89
docs/design/attention_backends.md 文档 modified 1.32

关键符号

supports_sink flash_attn_varlen_func AiterFlashAttentionImpl.__init__ AiterFlashAttentionImpl.forward

关键源码片段

vllm/v1/attention/backends/rocm_aiter_fa.py core-logic

核心变更文件,新增 supports_sink 声明、sinks 参数存储,以及 prefill/extend/decode 各路径的 sink 透传和内核切换逻辑。

# vllm/v1/attention/backends/rocm_aiter_fa.py
class AiterFlashAttentionBackend(AttentionBackend):
    # ...
    @classmethod
    def supports_sink(cls) -> bool:
        return True # 声明该后端支持 attention sink 机制class AiterFlashAttentionImpl(AttentionImpl):
    def __init__(self, ..., sinks: torch.Tensor | None = None):
        # ...
        self.sinks = sinks # 存储 sinks 张量供后续 forward 使用
​
    def forward(self, ...):
        # ...
        # decode 路径中:当 sinks 激活时切换到 unified_attention
        if (
            self.sliding_window[0] != -1
            or decode_max_query_len > 1
            or self.sinks is not None # 新增条件:sinks 存在时也走 unified_attention
        ):
            # 使用 unified_attention,因其支持 sinks 参数
            unified_attention(..., sinks=self.sinks)
        else:
            # 原 ASM paged attention 路径(不支持 sinks)
            rocm_aiter_ops.pa_fwd_asm(...)
vllm/_aiter_ops.py core-logic

包装函数 flash_attn_varlen_func 新增 sink_ptr 参数,向下游 aiter 库透传。

# vllm/_aiter_ops.py
class AiterOps:
    @staticmethod
    def flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        min_seqlen_q=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=None,
        alibi_slopes=None,
        return_lse=False,
        out=None,
        sink_ptr: torch.Tensor | None = None, # 新增参数
    ):
        from aiter import flash_attn_varlen_func
        return flash_attn_varlen_func(
            ...,
            sink_ptr=sink_ptr, # 透传给底层库
        )

评论区精华

模型和 AITer 库版本兼容性 question

tjtanaa 询问该 PR 针对哪个模型以及是否与 aiter v0.1.13 兼容。作者回复已在 GPT-OSS-120B(MI350x)上测试,但未验证 PyPI v0.1.14 是否包含 sink_ptr 参数,建议未来可用 try/except 兼容。

结论:当前直接传递参数,未做版本兼容处理,合并前未修改。 · 已解决

风险与影响

  1. 外部依赖兼容性:AITer 公共 PyPI 版本 v0.1.14 可能不包含 sink_ptr 参数,会导致 ImportErrorTypeError。当前代码未做版本检查或 try/except 保护。
  2. 性能退化:当 sinks 启用时,decode 路径从 ASM paged attention 切换到 unified_attention,可能引入性能下降(unified_attention 是更通用的 Triton 内核)。
  3. 测试覆盖不足:无配套测试文件,回归风险依赖 CI 隐式覆盖。

影响范围:仅影响 ROCm 平台上使用 ROCM_AITER_FA 后端的用户,且需 AITer 库支持 sink 参数。新增的 supports_sink 声明确保非 ROCm 平台或不同后端不受影响。文档同步更新使能力表准确。

外部依赖版本兼容风险 缺少测试覆盖 性能退化风险

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论