Prhub

#26318 [diffusion][jit_kernel] perf: varlen FA fast path for USPAttention masked branch

原始 PR 作者 mispa-ms 合并时间 2026-05-28 21:26 文件变更 6 提交数 9 评论 14 代码增减 +649 / -3

执行摘要

Varlen FA 加速 USPAttention masked 路径,Qwen-Image 推理提速 15%+

PyTorch SDPA 在非 None mask 下无法使用 FlashAttention,回退到 SM80 的 cutlassF 内核,在 Blackwell B200 上每个 attention 调用比原生 FA4(flash_fwd_sm100)慢约 7 倍,成为去噪循环的主要瓶颈。上游 Flash Attention 长期不支持密集 mask(Dao-AI Lab/flash-attention#409, #1990),因此需要一种 workaround 来提升 diffusion 模型(如 Qwen-Image)的推理性能。

该 PR 值得精读,尤其是对从事 Transformer inference 性能优化的工程师。核心设计模式(Triton 融合减少 launch、metadata 预计算复用、显式契约确保兼容性)具有很高的参考价值。新增的测试用例可作为 Triton 内核测试的范例。建议关注后续是否将该模式推广到其他 attention 变体(如 cross-attention、DPO 等)。

讨论亮点
  • 语义变更担忧:BBuf 指出 varlen 路径将 masked query rows 的输出置零,与 SDPA 的 key mask 语义(mask 仅作用于 key,query 保持)不同,对于其他使用 USPAttention 的调用者可能存在风险。讨论结果是:不再从 SDPA 内部推断,而是要求调用者显式提供 attn_mask_meta(通过 build_varlen_mask_meta 构建),明确接受“masked rows 将被 drop 并置零”的契约。若不提供,则回退到 SDPA,确保默认行为不变。
  • 端到端测试要求:BBuf 要求增加对比 varlen 路径与 SDPA 的等价性测试。作者补充了 test_varlen_uspattn_equivalence.py,验证有效行 match SDPA(FA 容差内),无效行精确为零。
  • 设备一致性检查:BBuf 建议添加 attn_mask.device == q.device 断言。作者在对应的 commit 中已添加此检查。
  • 环境变量开关:提供 SGLANG_VARLEN_FA=0 可完全回退到 SDPA,确保安全降级。

实现拆解

  1. 新增 Triton 融合 pack/scatter 内核:在 python/sglang/jit_kernel/diffusion/triton/varlen_pack_pad.py 中实现 fused_pack_qkvfused_scatter_to_padded 两个融合操作。fused_pack_qkv 通过一个 Triton kernel 将 Q/K/V 按 indices gather 到连续内存(原来需要 3 次 index_select),fused_scatter_to_padded 通过另一个 Triton kernel 将 packed 输出写回 padded 布局,无效位置填 0(原来需要 zeros + index_copy_)。同时提供 build_inv_indices 辅助函数用于从 pack indices 生成反向查找表。

  2. USPAttention 集成 varlen FA 快速路径:在 python/sglang/multimodal_gen/runtime/layers/attention/layer.py 中新增 build_varlen_mask_meta 函数,从 [B, S] 的 bool/int 掩码计算出 cu_seqlensindicesinv_indicesmax_seqlen 等元数据。USPAttention.forward_prepare_sdpa_mask 方法中,当提供了 attn_mask_meta 且满足形状条件时,走 varlen FA 路径:先 fused_pack_qkv,再 flash_attn_varlen_func,最后 fused_scatter_to_padded。若 indices 为空则回退到全零输出。环境变量 SGLANG_VARLEN_FA 可全局禁用。

  3. Qwen-Image 模型适配:在 python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py 中,于 QwenImageTransformer2DModel.forward 中预计算 attn_mask_meta(调用 build_varlen_mask_meta(joint_mask))并通过 cross_attention_kwargs 传递到每个 block,实现每请求一次 meta 计算,各 denoise block 复用。同时更新 QwenImageJointBlock.forward 接收 attn_mask_meta 参数并传给 USPAttention

  4. 导出新接口:在 python/sglang/multimodal_gen/runtime/layers/attention/__init__.py 中导出 build_varlen_mask_meta,方便外部调用。

  5. 测试配套

    • test_varlen_pack_pad.py:25 个用例覆盖 bf16/fp16、7 种生产相似形状(含 zero-text 边界)、非连续输入、空 mask,通过 torch.equal 验证融合内核与 PyTorch 参考(index_select / zeros+index_copy_)的比特一致。
    • test_varlen_uspattn_equivalence.py:端到端验证 varlen 路径与 SDPA 在有效行上对齐(FA 容差内),无效行输出精确为零。
    • 测试用例注册到 CI suite base-b-kernel-unit-1-gpu-largenightly-kernel-1-gpu
文件 模块 状态 重要度
python/sglang/jit_kernel/diffusion/triton/varlen_pack_pad.py JIT 内核 added 9.13
python/sglang/multimodal_gen/runtime/layers/attention/layer.py 注意力层 modified 7.56
python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py 模型逻辑 modified 6.43
python/sglang/jit_kernel/tests/diffusion/test_varlen_pack_pad.py 测试 added 8.03
python/sglang/jit_kernel/tests/diffusion/test_varlen_uspattn_equivalence.py 测试 added 7.93
python/sglang/multimodal_gen/runtime/layers/attention/__init__.py 注意力层 modified 4.54

关键符号

build_varlen_mask_meta fused_pack_qkv fused_scatter_to_padded build_inv_indices _fused_pack_qkv_kernel _fused_scatter_to_padded_kernel USPAttention.forward _prepare_sdpa_mask QwenImageJointBlock.forward QwenImageTransformer2DModel.forward

关键源码片段

python/sglang/jit_kernel/diffusion/triton/varlen_pack_pad.py core-logic

核心变更文件,新增两个融合 Triton 内核(pack 和 scatter),替代 5 次 PyTorch launch,是性能优化的核心。

"""Fused Triton pack/scatter kernels for the varlen mask path.Used by ``USPAttention.forward`` masked branch to gather Q/K/V at valid
positions and scatter the FA output back to the dense ``[B, S, H, D]`` layout.
"""from __future__ import annotationsimport torch
import triton
import triton.language as tl
​
​
@triton.jit
def _fused_pack_qkv_kernel(
    Q_ptr, K_ptr, V_ptr,
    Q_unpad_ptr, K_unpad_ptr, V_unpad_ptr,
    indices_ptr,
    HD, # H * D, 展平后的特征维度
    src_row_stride, # Q/K/V 中行之间的步长(B*S 行的下一行)
    dst_row_stride, # Q_unpad/K_unpad/V_unpad 中的行步长
    BLOCK_HD: tl.constexpr,
):
    """每个 packed 行一个程序;将 Q[src]、K[src]、V[src] 复制到目标行。"""
    out_row = tl.program_id(0)
    src_row = tl.load(indices_ptr + out_row).to(tl.int64)
    cols = tl.arange(0, BLOCK_HD)
    col_mask = cols < HD
    src_offset = src_row * src_row_stride + cols
    dst_offset = out_row * dst_row_stride + cols
    # 一次性加载 Q、K、V
    q_val = tl.load(Q_ptr + src_offset, mask=col_mask)
    k_val = tl.load(K_ptr + src_offset, mask=col_mask)
    v_val = tl.load(V_ptr + src_offset, mask=col_mask)
    tl.store(Q_unpad_ptr + dst_offset, q_val, mask=col_mask)
    tl.store(K_unpad_ptr + dst_offset, k_val, mask=col_mask)
    tl.store(V_unpad_ptr + dst_offset, v_val, mask=col_mask)
​
​
def fused_pack_qkv(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """将 ``[B, S, H, D]`` 的 Q/K/V 按 ``indices`` 打包成 ``[total_valid, H, D]``。    ``indices`` 是 int64 类型的扁平 ``B*S`` 位置,指示每个保留 token 的位置。
    非连续输入会在内部转换为连续。
    """
    assert q.shape == k.shape == v.shape, "Q/K/V 形状必须一致"
    assert q.dtype == k.dtype == v.dtype, "Q/K/V 数据类型必须一致"
    assert q.dim() == 4, "Q/K/V 必须是 [B, S, H, D]"
    assert indices.dtype in (torch.int32, torch.int64)
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    bs, seq, num_heads, head_dim = q.shape
    hd = num_heads * head_dim
    n_valid = indices.shape[0]
    # 空 mask 保护:FA varlen 不接受零长度输入
    if n_valid == 0:
        return (q.new_empty(0, num_heads, head_dim),
                k.new_empty(0, num_heads, head_dim),
                v.new_empty(0, num_heads, head_dim))
    block_hd = triton.next_power_of_2(hd)
    q_flat = q.view(bs * seq, hd)
    k_flat = k.view(bs * seq, hd)
    v_flat = v.view(bs * seq, hd)
    q_unpad = torch.empty(n_valid, hd, dtype=q.dtype, device=q.device)
    k_unpad = torch.empty(n_valid, hd, dtype=k.dtype, device=k.device)
    v_unpad = torch.empty(n_valid, hd, dtype=v.dtype, device=v.device)
    with torch.get_device_module().device(q.device):
        _fused_pack_qkv_kernel[(n_valid,)](
            q_flat, k_flat, v_flat,
            q_unpad, k_unpad, v_unpad,
            indices,
            hd, q_flat.stride(0), q_unpad.stride(0),
            BLOCK_HD=block_hd,
        )
    return (q_unpad.view(n_valid, num_heads, head_dim),
            k_unpad.view(n_valid, num_heads, head_dim),
            v_unpad.view(n_valid, num_heads, head_dim))
python/sglang/multimodal_gen/runtime/layers/attention/layer.py core-logic

修改文件:新增 `build_varlen_mask_meta` 函数和 USPAttention 的 varlen FA 快速路径,是整个优化在模型层面的集成点。

import os
from sglang.jit_kernel.diffusion.triton.varlen_pack_pad import (
    build_inv_indices, fused_pack_qkv, fused_scatter_to_padded,
)
from sglang.jit_kernel.flash_attention import flash_attn_varlen_func
from sglang.multimodal_gen.runtime.layers.attention.backends import (
    flash_attn as _fa_backend,
)# 环境变量开关:设置为 "0" 禁用 varlen FA 快速路径
_VARLEN_FA_ENABLED = os.environ.get("SGLANG_VARLEN_FA", "1") != "0"
​
​
def build_varlen_mask_meta(key_mask: torch.Tensor) -> dict:
    """从 ``[B, S]`` 的 key mask 构建 varlen FA 元数据。    返回 ``cu_seqlens``、``indices``、``inv_indices``、``max_seqlen``。
    将结果通过 ``joint_attention_kwargs`` 传入可启用 USPAttention 的 varlen
    FA 快速路径,该路径会将 masked query 行的输出置零——仅当下游丢弃或忽略
    这些行时才能使用。
    """
    assert key_mask.dim() == 2, "key_mask 必须为 [B, S]"
    bs, seq = key_mask.shape
    bool_mask = key_mask.to(dtype=torch.bool)
    valid_lens = bool_mask.sum(dim=1, dtype=torch.int32)
    # 展平后的有效位置索引
    indices = bool_mask.reshape(-1).nonzero(as_tuple=False).flatten()
    cu_seqlens = torch.zeros(bs + 1, dtype=torch.int32, device=key_mask.device)
    cu_seqlens[1:] = torch.cumsum(valid_lens, dim=0)
    inv_indices = build_inv_indices(indices, bs * seq)
    return {
        "cu_seqlens": cu_seqlens,
        "indices": indices,
        "inv_indices": inv_indices,
        "max_seqlen": seq, # 上界;FA varlen 实际使用 cu_seqlens 确定范围
    }
​
​
# 在 USPAttention.forward 中(_prepare_sdpa_mask 方法):
if attn_mask_meta is not None:
    # 快速路径:使用 varlen FA
    indices = attn_mask_meta["indices"]
    if indices.shape[0] == 0:
        # 全 False mask,直接返回全零
        return torch.zeros_like(q)
    q_unp, k_unp, v_unp = fused_pack_qkv(q, k, v, indices)
    out_unp = flash_attn_varlen_func(
        q=q_unp, k=k_unp, v=v_unp,
        cu_seqlens_q=attn_mask_meta["cu_seqlens"],
        cu_seqlens_k=attn_mask_meta["cu_seqlens"],
        max_seqlen_q=attn_mask_meta["max_seqlen"],
        max_seqlen_k=attn_mask_meta["max_seqlen"],
        softmax_scale=softmax_scale,
        causal=False,
        ver=_fa_backend.fa_ver,
    )
    return fused_scatter_to_padded(out_unp, attn_mask_meta["inv_indices"], bs, seq)
python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py data-contract

修改文件:在 QwenImageTransformer2DModel 中预计算 `attn_mask_meta` 并传递给每个 block,是实际触发优化的调用方。

from sglang.multimodal_gen.runtime.layers.attention import (
    USPAttention,
    build_varlen_mask_meta,
)# 在 QwenImageTransformer2DModel.forward 中,构建 joint mask 后:
joint_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
block_attention_kwargs["attn_mask"] = joint_mask
# 预计算 varlen 元数据,每请求只计算一次,所有 block 复用
block_attention_kwargs["attn_mask_meta"] = build_varlen_mask_meta(joint_mask)# 在 QwenImageJointBlock.forward 中接收并传递:
attn_mask_meta = cross_attention_kwargs.get("attn_mask_meta")
out = usp_attn(
    query, key, value,
    attn_mask=attn_mask,
    attn_mask_meta=attn_mask_meta, # 传入 varlen 元数据
    num_replicated_prefix=seq_len_txt,
)

评论区精华

语义变更:masked query rows 被置零 vs key mask 语义 设计

BBuf 指出 varlen 路径将 masked query rows 输出置零,与 SDPA 的 key mask 语义不同,对于其他调用者可能存在风险。建议要么只打包 K/V 保留 query,要么通过显式标记来明确这一假设。

结论:采用显式契约:调用者必须通过 `build_varlen_mask_meta` 构建 `attn_mask_meta` 并传入,明确接受“masked rows 被 drop 并置零”的语义。非 mask 路径回退到 SDPA,不影响已有行为。 · 已解决

建议添加 varlen 路径与 SDPA 的端到端等价测试 测试

BBuf 要求增加端到端测试,对比新 varlen 路径与旧 SDPA 路径的输出,验证正确性。

结论:作者新增 `test_varlen_uspattn_equivalence.py`,验证有效行在 FA 容差范围内匹配 SDPA,无效行精确为零。 · 已解决

设备一致性检查建议 正确性

BBuf 建议添加 `attn_mask.device == q.device` 断言以确保 mask 和输入在同一个设备上。

结论:作者在后续提交中添加了此项检查(PR #26318 的 commit 中体现)。 · 已解决

风险与影响

  • 语义改变风险:masked query rows 输出被置零而非 SDPA 的正常注意力输出。如果下游需要这些行的注意力输出,将导致错误。已通过显式 attn_mask_meta 契约和文档说明来缓解,且仅在传入 meta 时生效,不影响已有调用者。
  • 兼容性风险:varlen 路径严格限制在 bool/int 2D [B, S] mask、Q/K/V shape 一致、cuda 设备、bf16/fp16。float 相加 mask 和 cross-attention 形状保持原有 SDPA 路径,不会错误应用优化。
  • 数值差异风险:FlashAttention 与 SDPA 存在数值容差(bf16 下 rtol=1e-2, atol=5e-2),测试已验证有效行在容差内。但某些极端 case 可能超出容差,需关注。
  • 回归风险:非 mask 模型(FLUX.2-klein-4B)经测试零回归,且优化路径本就不会触发。环境变量可全局关闭。
  • 用户影响:使用 Qwen-Image 等 diffusion 模型的用户将获得 14-21% 的推理加速(饱和批),显著降低去噪循环延迟。其他模型用户不受影响。
  • 系统影响:新增 Triton 内核编译开销(首次运行时),但运行后无额外开销。减少 GPU launch 次数,对整体系统负载有正面影响。
  • 团队影响:提供了一种可重复的模式(build_varlen_mask_meta + 融合 pack/scatter),未来可推广到其他需要 mask 的 attention 场景。
语义变更(masked rows zero-filled) 仅适用于 bool/int 2D mask FA vs SDPA 数值容差 依赖环境变量控制

关联 Issue

#409 does flash attention support attention mask?
#1990 Does FA3 varlen func support pad between sequences?

完整报告

参与讨论