Prhub

#26206 [GDN] Optimize prefill QKV split dispatch

原始 PR 作者 BBuf 合并时间 2026-06-02 16:48 文件变更 4 提交数 8 评论 104 代码增减 +276 / -12

执行摘要

融合 Triton kernel 优化 GDN prefill QKV 拆分

通过 PyTorch profiler 定位到 GDN prefill 中 post-conv QKV 拆分操作(torch.split(...).view(...).contiguous())在 H200 上消耗了 18.97ms,是 prefill 中一个显著热点。该操作在 FLA 输入防护中触发多个 aten::copy_ kernel launch。为了降低开销,作者实现了一个融合 kernel,在单次 launch 内完成拆分和重新布局。同时,为了避免 chunk_delta_h kernel 的多配置 autotune 损坏状态池,但又要允许调优,将该 kernel 的硬编码配置改为环境变量参数化。

该 PR 值得精读,尤其关注以下设计决策:

  • 融合 kernel 的 stride 支持:同时支持连续和非连续输入,避免额外 contiguous() 调用。
  • 单配置 autotune 与环境变量:在安全性(避免多配置破坏状态池)和灵活性(允许调优)之间取得平衡。
  • 渐进式优化:先从最明显的开销切入,通过量化数据验证收益。后续可进一步优化 strided 输入路径。
讨论亮点
  • 非连续输入兼容性:gemini-code-assist 指出 mixed_qkvforward_extend 中通常是 transposed 后的非连续 tensor,若 fused kernel 不考虑 stride 可能读取错误。作者已在 kernel 中通过 MIXED_QKV_STRIDE_TMIXED_QKV_STRIDE_D 处理,但 reviewer 建议进一步确认或调用 .contiguous()。最终该问题被判定为已解决。
  • chunk_delta_h 默认配置回归:commit 记录显示,早期 commit 将 num_stages 默认设为 4,导致在 H100/H200 上因共享内存不足(K=V=256 需要 279KB,而最大 232KB)测试失败。后续将默认值回退到 2,并通过环境变量提供覆盖能力。
  • CI 状态:PR 经过多次 rerun-ci 和 lint 修复后通过 extra CI,最终获得 yuan-luo 的批准。

实现拆解

  1. 新增 fused QKV split Triton kernelpython/sglang/jit_kernel/triton/gdn_fused_proj.py):
    • 定义 fused_qkv_split_gdn_prefill_kernel JIT kernel,接受输入 tensor mixed_qkv 及其步长(stride),通过 tl.load 按 token 读入,再根据编译时常量计算出的偏移分别 tl.store 到 Q、K、V 输出 tensor。
    • 上层包装函数 fused_qkv_split_gdn_prefill 根据输入确定输出 shape,以 seq_len 为一维 grid 启动 kernel,BLOCK_SIZE 自动设为 qkv_dim 的 next power of 2。
  2. 集成到 GDN backendpython/sglang/srt/layers/attention/linear/gdn_backend.py):
    • forward_extend 方法中,通过条件判断(is_cuda()qkv_dim <= MAX_FUSED_QKV_SPLIT_DIM = 8192)选择是否调用 fused 函数。若条件不满足(非 CUDA 或维度超限),则 fallback 到原有的 torch.split + view 路径。
    • 由于 fused 函数期望的输入可能来自 causal_conv1d_fn 的结果,该结果在 transpose 后可能非连续,kernel 内部通过 stride 参数处理。
  3. 参数化 chunk_delta_h kernel 配置python/sglang/srt/layers/attention/fla/chunk_delta_h.py):
    • 引入三个环境变量 SGLANG_GDN_CHUNK_H_BVSGLANG_GDN_CHUNK_H_NUM_WARPSSGLANG_GDN_CHUNK_H_NUM_STAGES,默认值分别为 32、4、2(与之前硬编码一致),使用 int(os.getenv(key, default)) 读取。
    • 保持单 config 的 @triton.autotune 装饰器不变,避免多 config 导致 autotune 阶段损坏状态池。配置更新注释,说明环境变量的用途。
  4. 新增微基准测试benchmark/bench_linear_attention/bench_gdn_qkv_split.py):
    • 实现 split_reference 模拟旧路径,fused_qkv_split_gdn_prefill 调用新路径,使用 torch.testing.assert_close 做正确定性验证。
    • 分别对 contiguous 和 strided 布局进行性能测试,输出 baseline 和 fused 的延迟及加速比。
文件 模块 状态 重要度
python/sglang/jit_kernel/triton/gdn_fused_proj.py JIT 内核 modified 7.9
python/sglang/srt/layers/attention/linear/gdn_backend.py 注意力层 modified 6.79
python/sglang/srt/layers/attention/fla/chunk_delta_h.py FLA 层 modified 6.06
benchmark/bench_linear_attention/bench_gdn_qkv_split.py 基准测试 added 8.26

关键符号

fused_qkv_split_gdn_prefill_kernel fused_qkv_split_gdn_prefill forward_extend chunk_gated_delta_rule_fwd_kernel_h_blockdim64 split_reference

关键源码片段

python/sglang/srt/layers/attention/fla/chunk_delta_h.py configuration

将硬编码 Triton 配置参数化为环境变量,实现安全调优。

# 环境变量配置,默认值与之前硬编码一致
GDN_CHUNK_H_BV = int(os.getenv("SGLANG_GDN_CHUNK_H_BV", "32"))
GDN_CHUNK_H_NUM_WARPS = int(os.getenv("SGLANG_GDN_CHUNK_H_NUM_WARPS", "4"))
GDN_CHUNK_H_NUM_STAGES = int(os.getenv("SGLANG_GDN_CHUNK_H_NUM_STAGES", "2"))@triton.autotune(
    configs=[
        triton.Config(
            {"BV": GDN_CHUNK_H_BV},
            num_warps=GDN_CHUNK_H_NUM_WARPS,
            num_stages=GDN_CHUNK_H_NUM_STAGES,
        )
    ],
    key=["H", "K", "V", "BT", "USE_GK", "NT_BUCKET"],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k, v, w, v_new, g, gk, h, initial_state, initial_state_indices,
    cu_seqlens, chunk_offsets, T,
    H: tl.constexpr, Hg: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
    BT: tl.constexpr, BV: tl.constexpr,
    USE_G: tl.constexpr, USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr, INPLACE_UPDATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr, IS_VARLEN: tl.constexpr, NT_BUCKET: tl.constexpr,
):
    # ... (kernel 实现主体不变 )

评论区精华

非连续输入兼容性 正确性

gemini-code-assist 评论:在 forward_extend 中 mixed_qkv 通常来自 transpose(0,1),是非连续的,而 fused kernel 假设连续布局可能导致错误。建议调用 .contiguous() 或让 kernel 处理 strides。

结论:作者已在 kernel 中添加 stride 参数处理,该评论标记为已处理。PR 最终获批。 · 已解决

num_stages 默认值与共享内存限制 正确性

早期 commit 将 num_stages 默认设为 4,但在 H100/H200 上 test_chunk_gated_delta_rule::test_dim_256x256 因共享内存不足崩溃(需要 279KB,可用 232KB)。

结论:后续 commit 将默认 num_stages 回退到 2,并改为环境变量可调,以保证兼容现有硬件。 · 已解决

风险与影响

  • 非连续输入性能退化:当前 fused kernel 对 strided 输入加速仅 1.03x,与 contiguous 的 2.38x 差距明显。若生产环境中输入常为非连续,优化效果有限。可通过在调用 fused 前调用 .contiguous() 来规避,但会增加一次拷贝开销。
  • chunk_delta_h 配置兼容性:环境变量默认值(num_stages=2)兼容现有 NVIDIA GPU,但若用户随意设置为 4,可能在某些 shape 上因共享内存不足导致 kernel 启动失败。建议在文档中说明界限。
  • 单平台验证:所有性能数据仅在 B200 上采集,H100/H200 等其他 GPU 上的效果未明确验证。chunk_delta_h 的共享内存限制在 H100/H200 上已经暴露问题。
  • 缺少直接单元测试:benchmark 脚本包含了正确性检查,但未作为自动化测试集成。改动核心逻辑未添加独立单元测试,可能遗漏边界情况。
  • 用户影响:无 API 变化,使用 Qwen3.6 GDN 模型的用户可自动享受性能提升(Chat 吞吐 +2.7%,TTFT -17%)。其他 GDN 模型(如 Kimi-Linear)也可能受益。
  • 系统影响:新增环境变量 SGLANG_GDN_CHUNK_H_* 可调整 GPU 线程配置,用户可根据硬件情况微调。chunk_delta_h kernel 的配置调整不会影响其他模块。
  • 团队影响:新增的 fused kernel 和 benchmark 脚本需要维护,但其位于独立文件,耦合度低。
非连续输入性能退化 chunk_delta_h 配置兼容性 缺少直接单元测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论