Prhub

#25045 Fix FA3 cross-attention batched-decode for per-request varlen encoder

原始 PR 作者 zsj555 合并时间 2026-05-26 14:26 文件变更 1 提交数 1 评论 2 代码增减 +33 / -24

执行摘要

修复 FA3 跨注意力对变长编码器的批处理支持

FA3 注意后端的跨注意力元数据构建器仅支持 encoder_lens.numel() == 1,在非 CUDA Graph 路径中断言检查,CUDA Graph 重放路径则对所有请求复用 encoder_lens[0]。对于同一批次中混合不同编码器长度的多模态编码器-解码器模型,只有第 0 个请求获得正确的跨注意力,其余请求读取错误的 KV 缓存位置,产生乱码。在 bs=128 时,乱码输出比例高达 25-36%。

建议精读该 PR,尤其是 init_forward_metadata 中 fancy indexing 的构造方式。这是一个典型的注意力元数据构建 bug 修复,设计简洁,且验证充分,对于理解 SGLang 中跨注意力的批处理实现有很好的参考价值。

讨论亮点

该 PR 的 review 评论数量为 0,审核人 mickqian 直接批准,没有留下讨论记录。

实现拆解

  1. init_forward_metadata 方法中移除 assert encoder_lens.numel() == 1 断言及相关注释。
  2. init_forward_metadata 方法中,为自注意力(文本)的 page_table 构建使用基于每个请求的偏移量:通过 fancy indexing 构造 text_col = encoder_lens[i] + arange_text,替换原来使用单一 encoder_max_seq_len_k 全局偏移的方式。
  3. init_forward_metadata_replay_cuda_graph 方法中,将 encoder_max_seq_len_kencoder_lens[0] 改为 encoder_lens.max().item(),并将完整的 encoder_lens[:bs] 复制到 int32 和 cu_seqlens 缓冲区(原仅复制 encoder_lens[:1])。同时调整 page_table 的复制范围以匹配每个请求的实际长度。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/flashattention_backend.py 注意力 modified 6.9

关键源码片段

python/sglang/srt/layers/attention/flashattention_backend.py core-logic

唯一修改的文件,包含 FA3 注意力后端中跨注意力元数据构建的正确性修复,是修复的核心位置。

# python/sglang/srt/layers/attention/flashattention_backend.py
# 在 init_forward_metadata 方法中(非 CUDA Graph 路径的跨注意力处理部分)if forward_batch.encoder_lens is not None:
    # 移除原有的 assert 断言和注释,现支持变长编码器批处理
    metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
    metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
        (1, 0),
    )
    metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
​
    # 跨注意力 page_table:每行对应一个请求,cache_seqlens 中已
    # 记录每行实际编码器长度,超出 encoder_lens[i] 的垃圾数据不会被读取
    metadata.encoder_page_table = self.req_to_token_pool.req_to_token[
        forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
    ]
​
    # 自注意力(文本) page_table:文本偏移量基于每个请求的
    # encoder_lens[i],而非单一最大值。使用 fancy index 进行按行偏移
    text_max = metadata.max_seq_len_k
    arange_text = torch.arange(
        text_max, device=forward_batch.req_pool_indices.device
    )
    # text_col 形状 : (bs, max_seq_len_k),每行从 encoder_lens[i] 开始偏移
    text_col = forward_batch.encoder_lens.long().unsqueeze(1) + arange_text.unsqueeze(0)
    text_row = forward_batch.req_pool_indices.unsqueeze(1).expand(-1, text_max)
    metadata.page_table = self.req_to_token_pool.req_to_token[
        text_row, text_col
    ]

# init_forward_metadata_replay_cuda_graph 方法中(CUDA Graph 重放路径)if encoder_lens is not None:
    # 支持每请求变长编码器 ( 例如 MossVL 不同图片尺寸 )
    metadata.encoder_max_seq_len_k = int(encoder_lens.max().item())
    # 复制完整 batch 的 encoder_lens,而非仅复制第一个
    metadata.encoder_lens_int32[:bs].copy_(encoder_lens[:bs].to(torch.int32))
    metadata.encoder_cu_seqlens_k[1 : bs + 1].copy_(
        torch.cumsum(metadata.encoder_lens_int32[:bs], dim=0, dtype=torch.int32)
    )
    # 限制复制范围为实际请求数 bs
    metadata.encoder_page_table[:bs, : metadata.encoder_max_seq_len_k].copy_(
        self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
    )
    # 更新常规 page_table 时同样使用每请求偏移
    # ... ( 通过类似 fancy-index gather 实现 )

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险较低。变更仅影响 FA3 后端的跨注意力路径,且改动逻辑在 PR body 中通过对比测试验证(bs 8~128 下乱码率从 12-36% 降为 0%)。但变更未新增测试,依赖现有的 CI 测试覆盖;此外 fancy indexing 的引入可能带来极其微小的性能开销,但作者在 PR body 中通过速度测试证明无 measurable change。

影响仅限于使用 FA3 解码注意力后端、且模型包含跨注意力的多模态编码器-解码器模型(如 MossVL)。对于使用 FlashInfer 后端的用户无影响。修复后,这类模型在批处理推理和强化学习 rollout 中的正确性得到保证。

缺少测试覆盖 核心路径变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论