Prhub

#22931 [Fix/Kernel] Add JIT rmsnorm_hf kernel to fix transformers backend MMLU accuracy regression

原始 PR 作者 Jiminator 合并时间 2026-04-23 12:00 文件变更 5 提交数 10 评论 18 代码增减 +510 / -0

执行摘要

添加 JIT rmsnorm_hf 内核,修复 transformers 后端 MMLU 准确性回归并减少性能损失。

Commit 34ddf135fdTransformersBase 中引入 recursive_replace(),将 HF *RMSNorm 模块替换为 SGLang 的 sgl_kernel.rmsnorm。两个内核舍入顺序不同:sgl_kernel.rmsnorm 在 fp32 中乘权重然后转换,HF LlamaRMSNorm 先转换归一化输出到激活数据类型再乘权重。在 int4wo-128 量化下,约 1-ULP 差异在 32 层累积,翻转 3 个边界 MMLU 问题(0.703 → 0.656)。先前 naive 修复回滚更改导致显著性能回归,本 PR 通过添加 JIT 内核 rmsnorm_hf 恢复准确性并减少性能损失。

建议精读,重点关注内核设计中的舍入顺序处理(cast-before-weight-multiply)、性能优化策略(Warp/CTA 内核选择)以及测试中的回归防护方法(test_rmsnorm_hf_matches_hf_not_sgl),这些对处理类似量化或精度问题有借鉴价值。

讨论亮点
  • 性能基准:BBuf 要求提供内核级性能数据,Jiminator 回应了基准表格,显示新内核在多数场景下接近或优于基线,但大 token 数时可能变慢。
  • 测试覆盖:BBuf 指出正确性测试容差过松,可能无法捕获回归,Jiminator 添加了 test_rmsnorm_hf_matches_hf_not_sgl 测试,通过比较与 HF 和旧 SGL 参考的距离来确保内核遵循 HF 语义。
  • 内核设计:BBuf 建议优化内核布局,Jiminator 扩展了支持范围,添加 Warp 内核以覆盖小 hidden size(如 q/k RMSNorm 的 128),避免性能回退。
  • API 硬化:BBuf 提出输入维度验证和空输入处理,Jiminator 在包装器中添加了 input.dim() != 2 检查和空输入短路逻辑。
  • 集成防护:BBuf 指出 batch-invariant 模式可能绕过 HF 语义,Jiminator 修改逻辑使其回退到 forward_native;同时添加 self.weight.data.dtype == x.dtype 防护以确保数据类型匹配。

实现拆解

  1. 新增 CUDA 内核:在 python/sglang/jit_kernel/csrc/elementwise/rmsnorm_hf.cuh 中实现两个内核变体:Warp 内核(支持 hidden size 为 32 倍数且在 [32, 512) 范围)和 CTA 内核(支持 hidden size 为 512 倍数且 ≥512),均采用 HF 语义(先转换归一化 x 到 dtype 再乘权重),通过寄存器缓存减少全局内存读取。
  2. 新增 Python 包装器:在 python/sglang/jit_kernel/rmsnorm_hf.py 中定义 rmsnorm_hf 函数,使用 cache_onceload_jit 加载 JIT 模块,添加输入验证(2D 张量、fp16/bf16 类型、支持 hidden size 检查)和空输入短路处理。
  3. 修改层调度逻辑:在 python/sglang/srt/layers/layernorm.py 中,当 cast_x_before_out_mul=True 且无残差时,调度到 rmsnorm_hf 内核(条件包括 CUDA 可用、数据类型匹配、hidden size 支持),否则回退到 forward_native;同时修复 batch-invariant 模式下的防护。
  4. 启用 HF 语义:在 python/sglang/srt/models/transformers.pyreplace_rms_norm_class 中为非 Gemma 路径设置 cast_x_before_out_mul=True,确保 transformers 后端使用 HF 舍入顺序。
  5. 添加测试覆盖:在 python/sglang/jit_kernel/tests/test_rmsnorm_hf.py 中新增单元测试,包括正确性测试(对比 HF 参考)、输出参数测试、回归防护测试(确保内核更接近 HF 而非旧 SGL 语义)、空输入测试和 hidden size 支持测试,并注册到 CI 套件。
文件 模块 状态 重要度
python/sglang/jit_kernel/rmsnorm_hf.py JIT 内核 added 8.27
python/sglang/jit_kernel/tests/test_rmsnorm_hf.py 测试套件 added 7.55
python/sglang/srt/layers/layernorm.py 层归一化 modified 7.04
python/sglang/jit_kernel/csrc/elementwise/rmsnorm_hf.cuh CUDA 内核 added 6.02
python/sglang/srt/models/transformers.py 模型后端 modified 5.03

关键符号

is_supported_rmsnorm_hf_hidden_size _jit_rmsnorm_hf_module rmsnorm_hf hf_rmsnorm_reference sgl_rmsnorm_reference

关键源码片段

python/sglang/jit_kernel/rmsnorm_hf.py core-logic

新增 Python 包装器,是 JIT 内核的主要入口点,定义了 `rmsnorm_hf` 函数和 hidden size 支持检查。

def rmsnorm_hf(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """RMSNorm: `out = weight * cast_dtype(rsqrt(mean(x^2) + eps) * x)`.    遵循 HF 语义:先将归一化的 x 转换到激活 dtype,再乘权重。
    输入必须是 2D `(num_tokens, hidden_size)`;hidden_size 需满足 `is_supported_rmsnorm_hf_hidden_size`。
    空输入直接返回空输出,避免内核启动。
    """
    if input.dtype not in (torch.float16, torch.bfloat16):
        raise RuntimeError(f"rmsnorm_hf: input must be fp16 or bf16, got {input.dtype}")
    if input.dim() != 2:
        raise RuntimeError(f"rmsnorm_hf: input must be 2D, got {input.dim()}D")
    hidden_size = input.size(-1)
    if not is_supported_rmsnorm_hf_hidden_size(hidden_size):
        raise RuntimeError(
            f"rmsnorm_hf: unsupported hidden_size={hidden_size} "
            f"(must be a multiple of 32 in [32, 512) or a multiple of 512)"
        )
    if out is None:
        out = torch.empty_like(input)
    if input.numel() == 0:
        return out # 短路处理,避免内核启动错误
    module = _jit_rmsnorm_hf_module(hidden_size, input.dtype) # 缓存 JIT 模块
    module.rmsnorm_hf(input, weight, out, eps) # 调用底层 CUDA 内核
    return out
python/sglang/srt/layers/layernorm.py core-logic

修改现有 RMSNorm 类的 forward_cuda 方法,集成新内核调度逻辑,是功能启用的核心。

if self.cast_x_before_out_mul and residual is None:
    # 使用 HF 语义内核(先转换到 dtype 再乘权重)。
    if (
        _jit_rmsnorm_hf_available
        and x.dtype in (torch.float16, torch.bfloat16)
        and self.weight.data.dtype == x.dtype
        and is_supported_rmsnorm_hf_hidden_size(x.shape[-1])
    ):
        out = _jit_rmsnorm_hf(
            x.contiguous(), self.weight.data, self.variance_epsilon
        ) # 调用 JIT 内核
    else:
        # 回退:纯 Python HF 语义(已在 forward_native 中实现)。
        out = self.forward_native(x, None, None)
    if needs_reshape:
        out = out.reshape(original_shape)
    return out

评论区精华

内核性能基准与设计优化 性能

BBuf 要求提供内核级性能数据,并建议优化内核布局;Jiminator 回应了基准表格,显示新内核在多数场景下接近基线,但大 token 数时变慢,并添加了 Warp 内核以覆盖小 hidden size。

结论:Jiminator 扩展了内核支持范围,添加 Warp 内核,避免了性能回退,并提供了详细基准数据。 · 已解决

测试覆盖与回归防护 测试

BBuf 指出正确性测试容差过松,可能无法捕获回归;Jiminator 添加了 `test_rmsnorm_hf_matches_hf_not_sgl` 测试,通过比较距离确保内核遵循 HF 语义。

结论:新增回归防护测试,强化了测试套件,能有效捕获舍入顺序错误。 · 已解决

API 硬化与输入验证 正确性

BBuf 提出输入维度验证和空输入处理;Jiminator 在包装器中添加了 `input.dim() != 2` 检查和空输入短路逻辑。

结论:API 得到硬化,提高了鲁棒性,避免无效输入导致错误。 · 已解决

风险与影响

  1. 性能回归风险:新内核在特定 hidden size 和 token 数下可能比优化后的向量化基线慢(如基准显示 16384 hidden size 时 1.25x 延迟),需监控生产负载。
  2. 回归测试不足:尽管添加了回归防护测试,但容差设置(atol=1e-2)可能仍允许微小数值漂移,在极端量化场景下需验证。
  3. 兼容性风险:内核仅支持 hidden size 为 32 或 512 倍数,且仅限 fp16/bf16,不支持的配置将回退到较慢的 forward_native,可能影响非标准模型。
  4. 集成复杂性:调度逻辑在 layernorm.py 中新增条件分支,增加了代码复杂度,可能引入错误,尤其在残差处理或 batch-invariant 模式下。

对用户:修复了 transformers 后端在 MMLU 基准测试中的准确性回归(从 0.656 恢复到 0.703),提升模型输出质量,同时吞吐量回归从显著(868.824 → 780.230 tok/s)减少到可接受(845.62 tok/s)。对系统:新增 JIT 内核扩展了内核库,为未来类似精度问题提供模板;调度逻辑调整可能轻微增加运行时开销。对团队:展示了在数值精度和性能间权衡的设计模式,强化了测试中的回归防护实践。

核心路径变更 性能回归风险 测试覆盖不足

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论