Prhub

#42379 [Bugfix] Fix RMSNorm kernels to multiply in weight's native dtype

原始 PR 作者 liulanze 合并时间 2026-05-30 14:16 文件变更 2 提交数 2 评论 17 代码增减 +10 / -23

执行摘要

修复 RMSNorm 内核权重 dtype 精度回归

修复 #42325 报告的 RMSNorm 内核精度回归。PR body 指出该回归由 #40860 引入,导致模型输出严重偏差:'20-layer cumulative divergence exceeds 100, and only ~2% of output tokens match the reference'。继续使用异常内核将导致所有依赖 RMSNorm 的模型(尤其是 DeepSeek)产生错误结果。

建议所有用户升级此修复。对于内核贡献者,本 PR 提供了一个重要的数值精度决策案例:在编写 CUDA kernel 时,必须始终与 Python 前端的 dtype 规范保持一致,即使 FP32 直观上更精确,也要考虑累积误差。值得关注的设计决策:拒绝 'FP32 总是更好' 的假设,通过实验证据证明原生 dtype 的正确性。

讨论亮点
  • lm_eval 指标要求:审查者 @yewentao256 要求提供 lm_eval 指标确认无退化。作者提供 TinyLlama 结果,所有任务得分与 baseline 一致,验证了修复的安全性。
  • 设计争议:@zyongye 在 issue 评论中提出,非 Tensor Core 操作应在 FP32 上运行以最大化精度,并建议修改 Python IR 侧而非 CUDA。作者回应指出 Python 规范明确规定乘法的权重 dtype,且 FP32 累积误差在多层 RMSNorm 下不可接受(20 层累计超过 1.0)。最终决定维持原生 dtype 行为不变。
  • CI 失败判定:部分 CI 任务(如 fusedmoe 测试、MI300 multimodal 测试)出现失败,经 @AndreasKaratzas 确认均与硬件环境或无关配置有关,非本 PR 导致。PR 被强制合并。

实现拆解

  1. 修改 csrc/libtorch_stable/layernorm_kernels.cu:在 rms_norm_kernelfused_add_rms_norm_kernel(向量化路径和标量回退路径)共三处,将 x * s_variance * static_cast<float>(weight) 改为 (static_cast<scalar_t>(x * s_variance)) * weight,使得乘法在权重原生 dtype 下执行。

  2. 同步修改 csrc/libtorch_stable/layernorm_quant_kernels.cu:在 rms_norm_static_fp8_quant_kernelfused_add_rms_norm_static_fp8_quant_kernel 的对应三处,进行相同调整,并删除之前用于匹配非融合路径的临时舍入注释。确保量化融合内核与非融合复合路径(rms_norm → static_scaled_fp8_quant)的结果一致。

  3. 验证:通过 tests/kernels/core/test_layernorm.py(865 通过)和 tests/kernels/ir/test_layernorm.py(1442 通过,361 跳过)的全部测试。在 A100-SXM4-40GB 上运行回归复现脚本,确认修复前后 max diff 从 3.125e-02 降至 0.000e+00。使用 lm_eval 在 TinyLlama-1.1B 上评估 arc_challengehellaswagmmlu 等任务,得分完全一致,证明无精度退化。

文件 模块 状态 重要度
csrc/libtorch_stable/layernorm_kernels.cu CUDA 内核 modified 4.03
csrc/libtorch_stable/layernorm_quant_kernels.cu CUDA 内核 modified 4.16

关键符号

rms_norm_kernel fused_add_rms_norm_kernel rms_norm_static_fp8_quant_kernel fused_add_rms_norm_static_fp8_quant_kernel

关键源码片段

csrc/libtorch_stable/layernorm_kernels.cu core-logic

核心 RMSNorm kernel 逻辑更改,修复回归的主文件。

// csrc/libtorch_stable/layernorm_kernels.cu
// rms_norm_kernel 向量化路径 (VEC_SIZE=4)
// 计算 variance 后,每个线程处理 VEC_SIZE 个元素
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
    float x = static_cast<float>(src1.val[j]);
    // 先将规范化结果 x * s_variance 缩放到 scalar_t (BF16/FP16),
    // 再乘以权重的原生 dtype,确保与 Python spec `x.to(weight.dtype) * weight` 一致
    dst.val[j] = static_cast<scalar_t>(x * s_variance) * src2.val[j];
}

评论区精华

要求提供 lm_eval 指标验证无精度退化 测试

审查者 @yewentao256 请求添加 `lm_eval` 指标以证明 e2e 准确度无退化。作者提供了基于 TinyLlama-1.1B 的对比结果,显示 `arc_challenge`, `hellaswag`, `mmlu` 等任务得分完全一致。

结论:审查者确认无精度退化后批准 PR。 · 已解决

争议:FP32 更优精度 vs 匹配 Python spec 设计

核心贡献者 @zyongye 认为非 Tensor Core 操作应在 FP32 上以最大化精度,并建议修改 Python IR 侧(`layernorm.py`)而非 CUDA。作者 @liulanze 引用 Python spec 并指出 PF32 累积误差严重(20 层 RMSNorm 累计误差超 1.0),导致 token 匹配率仅 2%。其他审查者如 @yewentao256 和 @AndreasKaratzas 支持修复。

结论:维持原生 dtype 行为,拒绝 FP32 提议。PR 合并且 IR 端保持不变。 · 已解决

风险与影响

本 PR 修改了核心 CUDA kernel 的数值精度策略。主要风险包括:

  • 数值精度回归:虽然本 PR 旨在修复精度,但改变了 FP32 路径的舍入行为。然而,由于修改后与 Python spec 一致,且测试和 lm_eval 通过,风险很低。
  • FP8 量化路径兼容性:量化内核的同步修改可能改变 E4M3 量化边界上的行为,但作者通过测试验证了融合路径与非融合路径一致,且删除了之前为匹配而添加的舍入注释,风险可控。
  • 回溯兼容:修复后,依赖之前 FP32 行为的模型(如果存在)可能会产生不同结果。考虑到模型训练通常在算子原 dtype 下进行,预期不会有退化。风险低。

影响范围:所有使用 RMSNorm 的模型(几乎所有 transformer 模型),特别是 DeepSeek、Qwen、LLaDA 等使用 Q/K 层归一化的模型。修复后,输出 token 匹配率从 2% 恢复到正常水平。对 FP8 量化模型,融合量化内核与非融合路径结果一致,修复了潜在测试失败。对用户:升级后模型行为更准确,无需配置变更。对团队:此修复消除了一个严重回归问题,减少了用户投诉和调试负担。影响程度:高。对系统:代码变更仅 33 行,主要集中在 CUDA kernel,不涉及其它组件。影响范围小但影响深度大。

回归修复 数值精度敏感 CUDA 核心变更

关联 Issue

#40860 [Feat] DeepSeek V4 Rebased
#42325 [Bug]: RMSNorm kernel ignores weight dtype, always uses FP32 (regression in v0.20.0)

完整报告

参与讨论