Prhub

#40176 [ROCm] Support non-causal attention in ROCM_ATTN

原始 PR 作者 micah-wil 合并时间 2026-04-22 11:57 文件变更 4 提交数 6 评论 13 代码增减 +41 / -13

执行摘要

在 ROCm 注意力后端支持非因果注意力,修复 DFlash 推测解码测试。

根据 PR body,动机是修复 DFlash spec decoding with ROCM_ATTN,因为 PR #38300 添加的测试 test_dflash_speculators_model 在 ROCm 上失败,暴露了 ROCm 后端缺少对查询令牌双向注意力的支持。需要添加非因果注意力功能以使 DFlash 草案正常工作。

该 PR 值得精读,特别是 Triton 内核中注意力掩码逻辑的修改和元数据设计,展示了如何在多后端系统中处理功能标志。建议关注 prefix_prefill.py 中的掩码实现优化,以及 rocm_attn.py 中元数据的扩展方式,这对理解 vLLM 注意力后端架构有较高价值。

讨论亮点
  • 关键错误修复:gemini-code-assist[bot] 指出在 prefix_prefill.py 的非因果路径中,注意力分数未屏蔽超出 cur_batch_query_len 的令牌,可能导致 softmax 分母错误。micah-wil 通过添加 valid_kv 检查修复了掩码逻辑。
  • 代码清理:tjtanaa 建议移除 RocmAttentionMetadatacausal 字段的冗余注释,并避免使用 getattr 以确保性能。micah-wil 更新了代码,直接从 common_attn_metadata.causal 获取值,并移除了错误的 FlashAttentionMetadata 导入。
  • 后端一致性:讨论发现 RocmAiterUnifiedAttentionBackendsupports_non_causal 与实现不一致,因为底层 unified_attention 仅支持因果注意力。micah-wil 覆盖该方法返回 False,并在配置检查中添加错误提示。

实现拆解

  1. 更新 ROCm 注意力元数据:在 vllm/v1/attention/backends/rocm_attn.py 中,为 RocmAttentionMetadata 类添加 causal 字段(默认 True),并在 RocmAttentionMetadataBuilder.build 方法中从 common_attn_metadata 获取该值,以传递注意力因果性。
  2. 启用非因果支持:在 RocmAttentionBackend 类中添加 supports_non_causal 类方法返回 True,表明 ROCM_ATTN 后端支持非因果注意力。
  3. 修改 Triton 内核:在 vllm/v1/attention/ops/prefix_prefill.py 中,为 _fwd_kernel 添加 CAUSAL 参数,并调整注意力掩码逻辑以正确处理因果和非因果情况,确保超出边界的令牌被正确屏蔽。
  4. 处理其他后端:在 vllm/v1/attention/backends/rocm_aiter_unified_attn.py 中,覆盖 supports_non_causal 返回 False,因为该后端底层实现不支持非因果注意力,避免配置冲突并添加错误提示。
  5. 传播参数:在 vllm/v1/attention/ops/chunked_prefill_paged_decode.py 中添加 causal 参数,确保它传递给底层内核调用。
文件 模块 状态 重要度
vllm/v1/attention/backends/rocm_attn.py 注意力后端 modified 6.7
vllm/v1/attention/backends/rocm_aiter_unified_attn.py 注意力后端 modified 5.83
vllm/v1/attention/ops/prefix_prefill.py 注意力操作 modified 5.13
vllm/v1/attention/ops/chunked_prefill_paged_decode.py 注意力操作 modified 3.35

关键符号

RocmAttentionBackend.supports_non_causal RocmAiterUnifiedAttentionBackend.supports_non_causal RocmAttentionMetadataBuilder.build _fwd_kernel context_attention_fwd

关键源码片段

vllm/v1/attention/backends/rocm_attn.py core-logic

核心 ROCm 注意力后端实现,添加因果字段到元数据并启用非因果支持,是修复 DFlash 测试失败的关键。

@dataclass
class RocmAttentionMetadata:
    # ... 其他字段定义
    causal: bool = True # 新增:注意力是否因果,True 为因果(默认),False 为非因果class RocmAttentionBackend(AttentionBackend):
    @classmethod
    def supports_non_causal(cls) -> bool:
        return True # 新增:ROCM_ATTN 后端支持非因果注意力,用于 DFlash 草案
vllm/v1/attention/ops/prefix_prefill.py infrastructure

Triton 内核修改,实现因果和非因果注意力掩码逻辑,是支持双向注意力的核心。

def _fwd_kernel(..., CAUSAL: tl.constexpr = True, ...):
    # ... 其他代码
    if CAUSAL:
        # 因果注意力:仅考虑当前查询之前的键
        attn_mask = valid_kv & (offs_m[:, None] >= (start_n + offs_n[None, :]))
    else:
        # 非因果注意力:允许所有有效令牌间的注意力
        attn_mask = valid_kv
    qk = tl.where(attn_mask, qk, float("-inf")) # 应用掩码,将无效位置设为负无穷

评论区精华

非因果注意力掩码错误修复 正确性

gemini-code-assist[bot] 指出在 prefix_prefill.py 的非因果路径中,未屏蔽超出 cur_batch_query_len 的令牌,可能导致 softmax 分母错误,从而输出不正确。

结论:micah-wil 修复了掩码逻辑,添加 valid_kv 检查以确保超出边界的令牌被标记为 -inf,解决了潜在的正确性问题。 · 已解决

元数据字段定义优化 设计

tjtanaa 建议移除 RocmAttentionMetadata 中 causal 字段的冗余注释,并避免使用 getattr 以确保性能,因为未定义字段的查找可能很慢。

结论:micah-wil 更新了代码,移除注释并直接从 common_attn_metadata.causal 获取值,同时移除了错误的 FlashAttentionMetadata 导入。 · 已解决

后端支持标志一致性 设计

gemini-code-assist[bot] 发现 RocmAiterUnifiedAttentionBackend 的 supports_non_causal 与实现不一致,因为底层 unified_attention 仅支持因果注意力,导致配置冲突。

结论:micah-wil 覆盖 supports_non_causal 返回 False,并在配置检查中添加错误提示,确保用户不会误用非因果功能。 · 已解决

风险与影响

  • 回归风险:因果标志默认值为 True,应保持现有因果注意力行为不变,但需验证所有路径,特别是 Triton 内核中的掩码逻辑更改可能影响其他注意力模式。
  • 性能影响:添加条件分支和额外检查可能轻微增加开销,但在 Triton 内核中通过编译时常量优化,影响有限。
  • 兼容性:确保与 DFlash 推测解码和其他注意力后端的兼容;RocmAiterUnifiedAttentionBackend 显式不支持非因果,可能限制某些用例,但通过错误提示避免误用。
  • 测试覆盖:PR 未添加新测试,但修复了现有测试失败;建议补充单元测试以验证非因果路径的正确性。
  • 用户影响:ROCm 用户现在可以正常使用 DFlash 推测解码,提升模型推理性能和效率。
  • 系统影响:扩展了 ROCm 注意力后端的功能集,支持更复杂的注意力模式,增强了跨平台兼容性。
  • 团队影响:解决了由测试暴露的跨平台问题,提高了代码健壮性,并为未来类似功能扩展提供了参考设计。
核心路径变更 注意力掩码逻辑调整 后端兼容性差异

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论