Prhub

#21985 perf: eliminate attention DtoD copy by passing pre-allocated output to FA

原始 PR 作者 jasperjiaguo 合并时间 2026-04-25 03:05 文件变更 6 提交数 10 评论 8 代码增减 +38 / -5

执行摘要

消除注意力层 DtoD 拷贝,每层节省约 14μs

在 unified_attention_with_output 中,注意力后端内部新分配输出张量,然后通过 .copy_() 将其拷贝到预分配的输出缓冲区,导致每个注意力层产生约 14μs 的 Memcpy DtoD。对于 28 层模型,每次前向传递累积约 392μs。vLLM 通过直接传递 out= 参数避免了这一开销。本 PR 旨在消除这一冗余拷贝。

值得精读。本 PR 展示了如何通过 PyTorch 的 out 参数和 op schema 别名标注消除不必要的张量拷贝,是性能优化的经典案例。团队内的推理引擎开发人员应关注其中的设计权衡(如用 forward_batch 属性而非 kwargs 传递输出),以应用到其他相似场景。

讨论亮点

审核人 Qiaolin-Yu 已批准,无其他评论。主要技术决策包括:1)不使用 kwargs 传递 output(会破坏不支持 **kwargs 的后端如 FlashInfer),改为在 forward_batch 上存储 _attn_output;2)_attn_output 需要按 real_num_tokens 切片以匹配 FA 的 shape 校验;3)Op schema 的别名标注是消除拷贝的关键。全程无争议。

实现拆解

  1. radix_attention.py:在调用 attn_backend.forward 前将 output[:real_num_tokens] 赋值给 forward_batch._attn_output;调用后仅当返回张量 data_ptr 与 output 不同时才执行 copy_(),避免了强制拷贝。
  2. flashattention_backend.py:在 forward_extendforward_decode 中,从 forward_batch._attn_output 读取并 reshape(.view)为 FA 期望的 shape (-1, tp_q_head_num, v_head_dim),作为 out 参数传递给 flash_attn_with_kvcacheflash_attn_varlen_func
  3. sgl-kernel/python/sgl_kernel/flash_attn.pyflash_attn_with_kvcacheflash_attn_varlen_func 新增 out=None 参数,并将其传递给 torch.ops.sgl_kernel.fwdout 位置。
  4. jit_kernel 包装层(flash_attention.py 和 flash_attention_v3.py):透传 out 参数到下层调用。
  5. sgl-kernel/csrc/flash_extension.cc:修改 op schema,将 out 类型从 Tensor? 改为 Tensor(a!)?,返回类型从 Tensor 改为 Tensor(a!),让 PyTorch dispatch 知道返回别名 out,从而避免防御性拷贝。
文件 模块 状态 重要度
python/sglang/srt/layers/radix_attention.py 注意力层 modified 6.13
python/sglang/srt/layers/attention/flashattention_backend.py 注意力层 modified 6.37
sgl-kernel/python/sgl_kernel/flash_attn.py 内核封装 modified 5.19
python/sglang/jit_kernel/flash_attention.py JIT 内核 modified 4.67
python/sglang/jit_kernel/flash_attention_v3.py JIT 内核 modified 4.67
sgl-kernel/csrc/flash_extension.cc C++ 扩展 modified 4.9

关键符号

unified_attention_with_output forward_extend forward_decode flash_attn_with_kvcache flash_attn_varlen_func

关键源码片段

python/sglang/srt/layers/radix_attention.py core-logic

核心入口:在该文件中将预分配的输出绑定到 forward_batch._attn_output,并实现条件拷贝逻辑,是消除冗余拷贝的关键决策点。

# python/sglang/srt/layers/radix_attention.py
# unified_attention_with_output 函数中的关键变更部分
# 在调用注意力后端之前,将预分配的输出切片传递给后端
forward_batch._attn_output = output[:real_num_tokens] # 切片匹配 FA 的 query 长度# ... 调用后端 forward ...# 后端可能直接写入 output(data_ptr 相同),或者返回新张量
if ret.data_ptr() != output.data_ptr():
    # 仅当后端没有直接写入 output 时才拷贝(保障正确性)
    output[:real_num_tokens].view(ret.shape).copy_(ret)
python/sglang/srt/layers/attention/flashattention_backend.py core-logic

FA 后端实现:在该文件中提取 _attn_output 并作为 out 参数传递给所有 FA 调用路径,覆盖 forward_extend 和 forward_decode。

# python/sglang/srt/layers/attention/flashattention_backend.py
# forward_extend 开头从 forward_batch 中提取预分配输出
_fa_out = (
    forward_batch._attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
    if getattr(forward_batch, "_attn_output", None) is not None
    else None
)
# ... 后续原有逻辑 ...
# 在调用 flash_attn_with_kvcache 或 flash_attn_varlen_func 时传递 out
result = flash_attn_with_kvcache(
    q=...,
    k_cache=key_cache,
    v_cache=value_cache,
    # ... 其他参数 ...
    out=_fa_out, # 直接写入预分配缓冲区
)

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

回归风险较低:条件拷贝路径 (if ret.data_ptr() != output.data_ptr()) 保持兼容性,其他后端不受影响。潜在风险包括:1)_attn_output 为 None 时不会传递 out,行为与改造前一致;2)CUDA graph 捕获场景需要正确切片;3)非 FA 后端(如 FlashInfer、Triton)不会使用 out 参数,因此无影响。性能上,通过 29MB 的 DtoD 拷贝每层消除约 14μs,但仅在 FA3 后端生效。

对使用 FA3 后端的用户:前向性能提升 ~1.5%(低尾部延迟场景更明显);对使用其他后端的用户:无影响;对开发者:需要确保新添加的 out 参数在所有 FA 调用路径中正确传递。不改变 API 或配置,用户透明。

核心路径变更 兼容性(非 FA 后端) CUDA graph 切片

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论