Prhub

#25920 [bugfix] Honor cast_x_before_out_mul in RMSNorm.forward_cuda residual path

原始 PR 作者 charlotte12l 合并时间 2026-05-28 16:22 文件变更 4 提交数 9 评论 18 代码增减 +108 / -29

执行摘要

修复 RMSNorm 残差路径忽略 cast_x_before_out_mul 标志

cast_x_before_out_mul=True 且提供残差时,RMSNorm 输出的数学语义与配置的 HF 语义不一致。对于标准残差流 Transformer,从第 1 层开始,每一层都会产生错误的数值(如 PR body 所述)。此 bug 影响多个显式设置该标志并传入残差的模型(如 Qwen2/3、SDAR、MOSs-VL 等),导致最终输出偏离预期。

这是一个高质量 bugfix,修复了影响核心正确性的问题,且设计迭代清晰——从临时 fallback 到独立 kernel 再到合并到现有 kernel。值得精读:展示了如何在 CUDA kernel 中通过 if constexpr 实现多语义路径,以及如何平衡数值精度与性能。建议相关模型维护者关注黄金测试是否需要调整。

讨论亮点
  • 设计取舍:DarkSharpness 建议不要创建单独文件(fused_add_rmsnorm_hf.cuh),而是通过 constexpr flag 重用现有 kernel。最终采纳此方案,将 HF 语义作为模板参数 kCastXBeforeOutMul 融入现有 CUDA kernel,保持了代码库的简洁。
  • 精度争议:charlotte12l 在实现中引入 inp_res_cache 来缓存 fp32 的 sum,以通过 bitwise 测试。DarkSharpness 质疑是否需要 cache。charlotte12l 解释:HF 语义的参考实现(forward_native)使用 fp32 求和,如果不缓存,从 v[i](已 round 到 bf16)反推会导致偏差超出测试容忍度(1e-2)。最终保留 cache,确保与原生实现严格匹配。
  • 接受与改进:最终方案获得 BBuf 和 DarkSharpness 的认可,DarkSharpness 给出了 LGTM。

实现拆解

  1. 添加 hidden_size 兼容检查norm.py):新增 is_supported_jit_fused_add_rmsnorm_hidden_size 函数,限制 JIT kernel 适用的 hidden_size(>0, %16==0, <=8192),确保 kernel 在目标架构上安全运行。
  2. JIT kernel 改良norm.py + fused_add_rmsnorm.cuh):修改 _jit_fused_add_rmsnorm_module,将 cast_x_before_out_mul 作为编译参数传入 CUDA kernel。CUDA kernel 增加 kCastXBeforeOutMul 模板参数:当为 true 时,pass 1 缓存 fp32 的 input+residual 和(避免后续从 bf16 回读时精度损失),pass 2 先对 sum * rsqrt 结果下取整到窄类型,再与 weight 相乘(实现 HF 的 cast-before-multiply 语义)。
  3. 运行时调度layernorm.py):修改 RMSNorm.forward_cuda,当 residual is not Nonecast_x_before_out_mul=True 时,新增分发逻辑:检查 dtype 和 hidden_size 兼容后调用 JIT kernel,否则 fallback 到 forward_native(已正确实现 HF 语义)。
  4. 测试覆盖test_fused_add_rmsnorm.py):新增 forward_native_hf_reference 参考实现(纯 Python 的 fp32 求和与 cast-before-multiply),参数化测试 cast_x_before_out_mul 的 False/True 分支。宽松 tol(1e-2)验证全部 BS×hidden 组合,严格 bitwise 验证单 shape 的 forward_native 等价。
  5. 配套修复:修复了 CI 套件命名、isort 排序等问题,并附带了一次 rustfmt 清理(service_discovery.rs)。
文件 模块 状态 重要度
python/sglang/jit_kernel/norm.py JIT 内核 modified 6.91
python/sglang/srt/layers/layernorm.py 运行时调度 modified 6.68
python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py 单元测试 modified 6.09
python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm.cuh CUDA 内核 modified 4.96

关键符号

is_supported_jit_fused_add_rmsnorm_hidden_size _jit_fused_add_rmsnorm_module fused_add_rmsnorm RMSNorm.forward_cuda forward_native_hf_reference FusedAddRMSNormKernel::run

关键源码片段

python/sglang/jit_kernel/norm.py core-logic

核心修改:新增隐藏尺寸兼容检查函数,修改 JIT module 加载以传递 `cast_x_before_out_mul` 参数,修改 `fused_add_rmsnorm` 函数暴露该 flag。

# 检查 fused add rmsnorm JIT kernel 是否支持给定 hidden_size
def is_supported_jit_fused_add_rmsnorm_hidden_size(hidden_size: int) -> bool:
    # 要求 hidden_size > 0、能被 16 整除(对齐要求)、且不超过 8192
    return hidden_size > 0 and hidden_size % 16 == 0 and hidden_size <= 8192
​
​
@cache_once
def _jit_fused_add_rmsnorm_module(
    dtype: torch.dtype, cast_x_before_out_mul: bool # 新增参数,控制 HF 语义
) -> Module:
    # 将 `cast_x_before_out_mul` 作为编译参数传递给 CUDA kernel
    args = make_cpp_args(cast_x_before_out_mul, dtype)
    return load_jit(
        "fused_add_rmsnorm",
        *args,
        cuda_files=["elementwise/fused_add_rmsnorm.cuh"],
        cuda_wrappers=[("fused_add_rmsnorm", f"FusedAddRMSNormKernel<{args}>::run")],
    )
​
​
@debug_kernel_api
def fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    *,
    cast_x_before_out_mul: bool = False, # 添加仅关键字参数,默认为 False 保持向后兼容
) -> None:
    # 根据传入的 flag 选择对应语义的 JIT module
    module = _jit_fused_add_rmsnorm_module(input.dtype, cast_x_before_out_mul)
    module.fused_add_rmsnorm(input, residual, weight, eps)
python/sglang/srt/layers/layernorm.py dependency-wiring

运行时调度入口:在 `forward_cuda` 的残差路径中优先使用 JIT kernel,否则 fallback 到 `forward_native`。并导入新增的 JIT 函数。

if residual is not None:
    if self.cast_x_before_out_mul:
        # 检查 JIT kernel 的要求:dtype 为 fp16/bf16,weight 与 x 同 dtype,
        # 且 post_residual_addition 如果存在也同 dtype,hidden_size 受支持
        if (
            x.dtype in (torch.float16, torch.bfloat16)
            and self.weight.data.dtype == x.dtype
            and (
                post_residual_addition is None
                or post_residual_addition.dtype == x.dtype
            )
            and is_supported_jit_fused_add_rmsnorm_hidden_size(x.shape[-1])
        ):
            # 先处理 3 路求和:将 post_residual_addition 加入 residual 中(fp32 由 kernel 内部处理)
            if post_residual_addition is not None:
                residual = residual + post_residual_addition
            # 调用 JIT kernel,传递 cast_x_before_out_mul 标志
            _jit_fused_add_rmsnorm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
                cast_x_before_out_mul=self.cast_x_before_out_mul,
            )
            return x, residual
        # 条件不满足时 fallback 到 forward_native(已正确实现 HF 语义)
        return self.forward_native(x, residual, post_residual_addition)
python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm.cuh core-logic

CUDA kernel 核心修改:添加 `kCastXBeforeOutMul` 模板参数,在 pass 2 中实现 cast-before-multiply 语义,并通过 `inp_res_cache` 缓存 fp32 sum 以保持数值等效。

// 当 kCastXBeforeOutMul 为 true 时执行 HF 语义:先对(input+residual)* rsqrt 的结果 cast 到窄类型,再乘 weight
// valf 是 fp32 的 input+residual 和(来自 inp_res_cache 或 v[i] 的 fp32 转换)
template <bool kCastXBeforeOutMul, typename packed_t>
SGL_DEVICE packed_t rms(float2 valf, packed_t& weight, float rsqrt_square_sum) {
    float2 weightf = device::cast<fp32x2_t, packed_t>(weight);
    if constexpr (kCastXBeforeOutMul) {
        // HF 语义:将 (sum * rsqrt) 结果先 cast 回窄类型(如 bf16),再转回 fp32 与 weight 相乘
        auto rounded = device::cast<packed_t, fp32x2_t>(
            make_float2(valf.x * rsqrt_square_sum, valf.y * rsqrt_square_sum));
        valf = device::cast<fp32x2_t, packed_t>(rounded);
        return device::cast<packed_t, fp32x2_t>(
            make_float2(valf.x * weightf.x, valf.y * weightf.y));
    }
    // 默认语义:直接乘 weight 再乘 rsqrt,所有运算在 fp32 中完成
    return device::cast<packed_t, fp32x2_t>(
        make_float2(valf.x * weightf.x * rsqrt_square_sum,
                    valf.y * weightf.y * rsqrt_square_sum));
}// … 在 pass 1 中,若 kCastXBeforeOutMul 为 true,将 fp32 的 input+residual 和缓存到 inp_res_cache
if constexpr (kCastXBeforeOutMul) {
    inp_res_cache[i] = inp_res; // inp_res 是 fp32 的 x+residual
}// pass 2 中,从 v[i](已 round 到 DType)或 inp_res_cache(fp32)读取 sum
float2 valf;
if constexpr (kCastXBeforeOutMul) {
    valf = inp_res_cache[i]; // 使用 fp32 精度的 sum,与 forward_native 一致
} else {
    valf = device::cast<fp32x2_t, packed_t>(v[i]);
}
v_out[i] = rms<kCastXBeforeOutMul>(valf, v_weight[i], rsqrt_square_sum);

评论区精华

代码设计:复用现有 kernel vs 创建单独文件 设计

DarkSharpness 建议不要创建 `fused_add_rmsnorm_hf.cuh` 单独文件,而是通过 constexpr flag 重用现有 `fused_add_rmsnorm.cuh` 内核。charlotte12l 提出三种方案,最终采用方案 1(直接在现有 JIT kernel 中加 flag)。

结论:采纳 DarkSharpness 建议,将 HF 语义作为 `kCastXBeforeOutMul` 模板参数融入现有 kernel。 · 已解决

精度问题:是否需要 fp32 sum cache 性能

DarkSharpness 询问为什么需要 `inp_res_cache`。charlotte12l 解释:HF 语义的参考实现(forward_native)使用 fp32 求和,如果不缓存,从 bf16 回读会导致精度损失超出测试 tol,因此保留 cache 确保 bitwise 等效。

结论:保留 `inp_res_cache`,因为它对于通过 bitwise 测试是必要的,且仅当 `kCastXBeforeOutMul` 为 true 时启用,无额外开销。 · 已解决

测试策略:宽松与 bitwise 测试 测试

charlotte12l 在测试中同时使用宽松 tol(1e-2)覆盖多种 shape 和严格 bitwise 测试覆盖单 shape,以确保数值正确性。

结论:测试设计被接受,`forward_native_hf_reference` 作为参考实现。 · 已解决

风险与影响

  1. 数值回退风险:非 HF 路径(kCastXBeforeOutMul=false)因 if constexpr 保证零开销,无需担忧性能退化。但所有调用点都需要重新编译 JIT kernel(首次运行时)。
  2. hidden_size 兼容性is_supported_jit_fused_add_rmsnorm_hidden_size 限定了 %16 对齐和 <=8192,超出范围的模型会安全降级到 forward_native。若未来需要支持更大 hidden_size,需扩展内核。
  3. 精度变化:之前因 bug 而使用非 HF 语义的模型,输出会略微变化,更贴近参考实现。依赖固定 golden(如 CI 缓存)的测试可能需要刷新 tolerance。PR body 列出了 7 个受影响模型文件。
  4. 性能影响:从 forward_native(纯 Python)切换到 JIT kernel 通常期望正向收益;对于已使用 flashinfer 内核的路径无影响。
  • 用户/模型:所有设置 cast_x_before_out_mul=True 并传入残差的模型(sdar.py、sdar_moe.py、moss_vl.py、qwen2.py、qwen3.py、transformers.py、vision.py)自动获得正确的数值输出,无需代码更改。
  • 系统:JIT kernel 缓存机制不变;新增模板参数后,相同参数组合只编译一次,无重复开销。
  • 团队:后续维护者在扩展 fused_add_rmsnorm 内核时需留意 kCastXBeforeOutMul 的一致性;测试框架已提供参考更易于验证。
核心路径变更 影响多模型精度 需刷新 CI golden JIT kernel 重新编译

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论