Prhub

#39120 [ROCm] Fix cu_seqlens_q off-by-one in AITER FA speculative decode path

原始 PR 作者 Bortlesboat 合并时间 2026-04-20 02:34 文件变更 1 提交数 8 评论 9 代码增减 +2 / -2

执行摘要

修复 ROCm AITER FlashAttention 在推测解码路径中的 cu_seqlens_q off-by-one 错误。

PR body 指出,在 speculative decode 路径中,cu_seqlens_q 被错误切片为 query_start_loc[:num_decodes],但正确的应该是 [:num_decodes + 1],因为 cu_seqlens_q 是累积长度数组,需要 num_seqs + 1 个条目。引用上游 AITER 实现作为验证。

此 PR 值得精读,因为它展示了累积长度数组在 attention 后端中的正确使用模式,对于理解推测解码和 ROCm 集成有帮助。

讨论亮点

审核中,tjtanaa 要求添加上游 AITER 实现链接作为证明,以确保修复的正确性。讨论无争议,修复被批准。

实现拆解

  1. 修改 descale_shape 计算:在 vllm/v1/attention/backends/rocm_aiter_fa.py 中,将 descale_shapeattn_metadata.query_start_loc[:num_decodes].shape[0] - 1 改为直接使用 num_decodes,简化计算并匹配后续逻辑。
  2. 修复 cu_seqlens_q 切片:将 cu_seqlens_q 参数从 attn_metadata.query_start_loc[:num_decodes] 改为 attn_metadata.query_start_loc[: num_decodes + 1],确保提供正确的累积长度数组。
  3. 验证与上游一致:参考 AITER 上游 unified_attention 实现,确认 cu_seqlens_q 需要 num_seqs + 1 个条目。
  4. 无配套改动:本次变更仅涉及核心逻辑文件,没有测试、配置或部署配套改动。
文件 模块 状态 重要度
vllm/v1/attention/backends/rocm_aiter_fa.py 注意力后端 modified 5.07

关键源码片段

vllm/v1/attention/backends/rocm_aiter_fa.py core-logic

这是唯一被修改的文件,包含 AITER FlashAttention 后端的推测解码路径核心逻辑修复。

if decode_max_query_len > 1:
    from aiter.ops.triton.unified_attention import unified_attention
​
    descale_shape = (
        num_decodes, # 修复前 : attn_metadata.query_start_loc[:num_decodes].shape[0] - 1
        key_cache.shape[2],
    )
    unified_attention(
        q=query[:num_decode_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_decode_tokens],
        cu_seqlens_q=attn_metadata.query_start_loc[: num_decodes + 1], # 修复前 : [:num_decodes]
        max_seqlen_q=decode_max_query_len,
        seqused_k=attn_metadata.seq_lens[:num_decodes],
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=attn_metadata.block_table[:num_decodes],
        softcap=self.logits_soft_cap,
        q_descale=None,
        k_descale=layer._k_scale.expand(descale_shape),
        v_descale=layer._v_scale.expand(descale_shape),
    )
    return

评论区精华

确认 cu_seqlens_q 切片正确性 正确性

tjtanaa 要求添加上游 AITER 实现链接作为验证,确保修复符合上游行为。

结论:修复被批准,链接已添加,无争议。 · 已解决

风险与影响

风险较低,但需注意:

  • 核心路径变更:修改了 attention 后端的推测解码路径,若切片逻辑仍有误,可能影响 ROCm 平台多令牌解码的正确性。
  • 依赖上游一致性:修复基于上游 AITER 实现,若上游变更可能需同步调整。

影响范围有限:仅影响使用 ROCm 平台且启用推测解码(多令牌解码)的场景。修复后,确保 cu_seqlens_qdescale_shape 计算正确,提升系统稳定性。

核心路径变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论