Prhub

#26588 Optimize Gemma4 H200 MoE and extend attention

原始 PR 作者 BBuf 合并时间 2026-06-06 14:14 文件变更 4 提交数 10 评论 98 代码增减 +337 / -35

执行摘要

优化 Gemma4 H200 MoE 与 extend attention 性能

Gemma4 模型拥有 E=128, N=704 的独特 MoE 结构,现有 MoE Triton 配置未针对该 shape 优化,且 Hopper extend attention block size 默认值在 Lq 129-256 区间并非最优。MTP 推测解码时 QKV RMSNorm 的计算成为小批量瓶颈。PR 旨在通过 Triton 调优和 kernel 优化缩小这些差距。

推荐精读。尤其注意 kernel dedup 设计方法和 BF16 精度分析。对于 Gemma4 部署有直接收益;对编写数值稳定的 Triton kernel 有参考价值。

讨论亮点

内核去重建议:审查者 yuan-luo 指出 _gemma_qkv_rmsnorm_by_head_kernel(2D grid)与既有 _gemma_qkv_rmsnorm_kernel(1D grid)实质相同,维护两份会同步困难。BBuf 接受建议并在后续 commit 中通过 BY_HEAD constexpr 合并为一个 kernel,验证了精度和性能等价。

Extend attention block size 自动调优:yuan-luo 提出硬编码阈值(128/256)应改为 autotune key。BBuf 同意并承诺作为后续改进,暂不阻塞当前性能提升。

H100/H20 配置缺失:yuan-luo 建议增加 H100 和 H20 的 MoE 配置。BBuf 表示在没有相应硬件验证的情况下不应复制 H200 配置,留作后续。

精度回归根本原因:PR body 详细分析了两个融合 kernel 导致精度下降的原因——BF16 累加顺序和 softmax reduction tree 差异使路由 logits 出现系统性偏移,累积后扩大了 MTP draft/target 分布差异。这一分析是重要的工程教训。

实现拆解

  1. MoE Triton 配置调优:在 triton_3_6_0 配置目录下新增两个 JSON 文件,分别为 normal 和 down 投影提供 H200 专用的 block size / warp / stage 调优参数(E=128, N=704)。这些配置覆盖从 1 到 1024 的各类 M 值,通过调整 BLOCK_SIZE_M/N/K、GROUP_SIZE_M、num_warps 和 num_stages 使 Triton 编译器输出更优的 GPU kernel。

  2. Extend-attention block size 分段调整:修改 extend_attention.py 中的 _get_block_sizes_for_extend_attention 函数,在 Hopper 架构(CUDA capability >= 9)下将 Lq ≤ 256 这一档拆分为 ≤128 和 128–256 两档,分别使用 (128,64) 和 (64,64) 的 (BLOCK_M, BLOCK_N) 组合,以匹配 Gemma4 典型请求长度,减少 prefill 阶段耗时。

  3. 小批量 QKV RMSNorm by-head kernel:在 gemma4_fused_ops.py 中新增辅助函数 _gemma_qkv_rmsnorm_store,封装单 head 的 load-rms-scale-store 逻辑;然后改造 _gemma_qkv_rmsnorm_kernel,通过编译时常量 BY_HEAD 支持两种 launch 模式:当 BY_HEAD=True 时 grid 为 (M, total_heads),每个 program 处理一个 token 的一个 head,适合小批量;BY_HEAD=False 时 grid 为 (M,),串行循环 heads,适合大批量。外层 Python 函数 gemma_qkv_rmsnorm 根据 M 或传入参数自动选择 launch 路径。该 kernel 原位归一化 Q、K、V,使用 strided 视图避免 contiguous 拷贝。

  4. (回滚)MoE routing 融合:原始提交包含两个融合优化:a) 将 Gemma4Router.forward 中的 rmsnorm + scale 合并为单步 fused kernel;b) 新增 _gemma4_topk_softmax_scale_kernel 将 topk、softmax、per-expert 缩放合并为一个 pass。由于 BF16 精度回归导致 MTP GSM8K 得分从 0.445 降至 0.360,两个 commit 被先后回滚。回滚后精度恢复,其余优化保留。

  5. 配套验证:无新增测试文件,但通过 CI 中的 test_gemma4_mtp_26b_a4b_extra 测试和自定义 benchmark 验证了性能提升和准确性等价。

文件 模块 状态 重要度
python/sglang/srt/layers/gemma4_fused_ops.py 融合算子 modified 7.77
python/sglang/srt/layers/moe/moe_runner/triton_utils/configs/triton_3_6_0/E=128,N=704,device_name=NVIDIA_H200.json MoE 配置 added 5.78
python/sglang/srt/layers/moe/moe_runner/triton_utils/configs/triton_3_6_0/E=128,N=704,device_name=NVIDIA_H200_down.json MoE 配置 added 5.78
python/sglang/srt/layers/attention/triton_ops/extend_attention.py 扩展注意力 modified 3.93

关键符号

_gemma_qkv_rmsnorm_store _gemma_qkv_rmsnorm_kernel gemma_qkv_rmsnorm _get_block_sizes_for_extend_attention

关键源码片段

python/sglang/srt/layers/gemma4_fused_ops.py core-logic

核心逻辑修改,新增辅助函数 `_gemma_qkv_rmsnorm_store`,改造 `_gemma_qkv_rmsnorm_kernel` 支持 BY_HEAD 双模式,实现 QKV RMSNorm 小批量加速与 kernel 去重。

# 辅助函数:单个 head 的 RMSNorm 归一化(带可选权重)
@triton.jit
def _gemma_qkv_rmsnorm_store(
    X_ptr, W_ptr, stride_m, m, h, cols, mask,
    HEAD_DIM: tl.constexpr, eps, HAS_WEIGHT: tl.constexpr,
):
    # 计算每个 head 的偏移量
    off = m * stride_m + h * HEAD_DIM + cols
    x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
    # RMSNorm: x / sqrt(mean(x^2) + eps)
    rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps)
    out = x * rrms
    # 若有权重,则乘以权重(权重长度为 head_dim,广播使用)
    if HAS_WEIGHT:
        w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        out = out * w
    tl.store(X_ptr + off, out.to(X_ptr.dtype.element_ty), mask=mask)
​
​
@triton.jit
def _gemma_qkv_rmsnorm_kernel(
    Q_ptr, K_ptr, V_ptr, Q_w_ptr, K_w_ptr,
    stride_q_m, stride_k_m, stride_v_m,
    NUM_Q_HEADS: tl.constexpr, NUM_KV_HEADS: tl.constexpr,
    HEAD_DIM: tl.constexpr, eps, HAS_KV: tl.constexpr,
    BY_HEAD: tl.constexpr, BLOCK: tl.constexpr,
):
    # 通过 BY_HEAD 实现两种 launch 形状:
    # - True:grid (M, total_heads),每个 program 处理一个 token 的一个 head
    # - False:grid (M,),串行循环所有 head
    m = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < HEAD_DIM
​
    if BY_HEAD:
        # 每个 program 处理单个 head
        h_all = tl.program_id(1)
        if h_all < NUM_Q_HEADS:
            _gemma_qkv_rmsnorm_store(
                Q_ptr, Q_w_ptr, stride_q_m, m, h_all, cols, mask, HEAD_DIM, eps, True
            )
        elif HAS_KV and h_all < NUM_Q_HEADS + NUM_KV_HEADS:
            h = h_all - NUM_Q_HEADS
            _gemma_qkv_rmsnorm_store(
                K_ptr, K_w_ptr, stride_k_m, m, h, cols, mask, HEAD_DIM, eps, True
            )
        elif HAS_KV:
            h = h_all - NUM_Q_HEADS - NUM_KV_HEADS
            # V head 不使用权重(weight=1)
            _gemma_qkv_rmsnorm_store(
                V_ptr, Q_w_ptr, stride_v_m, m, h, cols, mask, HEAD_DIM, eps, False
            )
    else:
        # 串行循环:适用于大批量,避免过多 program
        for h in tl.static_range(NUM_Q_HEADS):
            _gemma_qkv_rmsnorm_store(
                Q_ptr, Q_w_ptr, stride_q_m, m, h, cols, mask, HEAD_DIM, eps, True
            )
        if HAS_KV:
            for h in tl.static_range(NUM_KV_HEADS):
                _gemma_qkv_rmsnorm_store(
                    K_ptr, K_w_ptr, stride_k_m, m, h, cols, mask, HEAD_DIM, eps, True
                )
            for h in tl.static_range(NUM_KV_HEADS):
                _gemma_qkv_rmsnorm_store(
                    V_ptr, Q_w_ptr, stride_v_m, m, h, cols, mask, HEAD_DIM, eps, False
                )

评论区精华

合并 QKV RMSNorm kernel (BY_HEAD) 设计

审查者 yuan-luo 指出 `_gemma_qkv_rmsnorm_by_head_kernel`(2D grid)与既有 `_gemma_qkv_rmsnorm_kernel`(1D grid)实质相同,维护两份会同步困难。

结论:BBuf 接受建议,在后续 commit 中通过 `BY_HEAD` constexpr 合并为一个 kernel,验证了精度和性能等价。 · 已解决

Extend attention block size 应 autotune 性能

yuan-luo 建议将硬编码阈值(128/256)改为 autotune keys。BBuf 同意但作为后续改进。

结论:作为后续任务,当前 PR 暂保留硬编码。 · resolved-followup

增加 H100/H20 MoE 配置 other

yuan-luo 建议增加 H100 和 H20 的 MoE 配置。BBuf 表示无对应硬件验证,暂不添加。

结论:暂不添加,留作后续。 · deferred

MoE 路由融合精度回归分析 正确性

PR body 和提交历史显示融合 topk-softmax-scale 和 fused router RMSNorm 导致 BF16 精度下降约 5pp GSM8K。作者通过二分法定位两个 commit 共同作用,分析了 BF16 累加顺序和 softmax reduction tree 差异。

结论:两个融合 kernel 被回滚,精度恢复。未来需要 bit-exact 实现才能启用。 · 已解决

关联 PR 合并努力 other

pyc96 在 issue 中提及相关 PR (#26502, #25461),建议合并优化。BBuf 表示会 review。

结论:未在本次 PR 中合并,仅作为后续参考。 · not-resolved

风险与影响

  1. 精度回归风险:路由融合优化在 BF16 下因浮点顺序不等价导致精度下降(已回滚)。若未来有人恢复或修改 fused kernel 而未严格 bit-wise 校验,可能再次引入。
  2. 核心路径变更:extend attention block size 的修改会影响所有 Hopper 架构上的 prefill 性能。尽管对 Gemma4 正向,但未在非 Gemma4 模型上验证,可能存在细微回归。
  3. H200 特有配置:新增 MoE 配置仅针对 H200 调优,在其他 Hopper 卡(如 H100, H20)上可能不是最优,但不会导致错误。
  4. 缺少直接单元测试:QKV RMSNorm 和 MoE 配置变更没有对应的 Python 单元测试,仅依赖 CI 中的端到端 MTP 测试覆盖。
  5. 配置可维护性:自动调优(Triton Autotune)尚未集成,手动维护 JSON 配置可能跟不上模型变化。

用户:Gemma4 用户将直接受益于更低的 TTFT(-23.3%)和 TPOT(-18.2%),以及约 18% 的端到端延迟改善。其他模型用户受影响极小(extend attention block 调整可能略有影响,但整体中性)。
系统:修改了 attention 和 MoE 两个关键计算路径,但改动范围小,无新增依赖。
团队:维护者需要理解两个 kernel 路径的共存;融合路由的教训值得团队其他优化项目参考。

精度回归风险 核心路径变更 缺少测试覆盖 H200 特有配置

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论