Prhub

#26430 Fix GemmaRMSNorm gemma_weight buffer storage for Qwen3.5

原始 PR 作者 guapisolo 合并时间 2026-05-28 18:42 文件变更 1 提交数 1 评论 4 代码增减 +5 / -2

执行摘要

修复 GemmaRMSNorm buffer 存储导致 CUDA Graph 失效

来自 PR body:运行 Qwen3.5 时,首次权重更新后 logp 高达约 3.0。根本原因是 #22673 将 weight 输入改为 self.gemma_weight,但 weight_loader 中 self.gemma_weight = param.data + 1.0 会分配新内存,导致 CUDA Graph 仍使用旧权重。

这是一个值得精读的微型实例:演示了 PyTorch 中 = 赋值与原地操作在 CUDA Graph 上下文下的关键区别。团队可借鉴此模式审查其他存在 buffer = expr 赋值且参与 CUDA Graph 捕获的模块。

讨论亮点

PR 无 review 评论,仅有两位 maintainer 的快速 approval。说明变更逻辑清晰、风险可控,社区共识高。

实现拆解

  1. 修改初始化:在 __init__ 中将 gemma_weight buffer 的初始化从 self.weight.data + 1.0(创建新 tensor)改为 torch.ones_like(self.weight)(分配持久 buffer)。这样 buffer 的存储地址在模型生命周期内固定不变。
  2. 原地更新权重:在 _weight_loader 中将 self.gemma_weight = param.data + 1.0(重新赋值,新内存)替换为 torch.add(param.data, 1.0, out=self.gemma_weight)(原地加法,写入同一 buffer)。确保 CUDA Graph 捕获的 buffer 引用始终有效。
  3. 不变之处_forward_impl 和其余前向逻辑保持不变,因为 _forward_impl 实际使用 self.weight.data 而非 self.gemma_weight,避免了符号歧义。
文件 模块 状态 重要度
python/sglang/srt/layers/layernorm.py 层归一化 modified 5.63

关键符号

GemmaRMSNorm.__init__ GemmaRMSNorm._weight_loader

关键源码片段

python/sglang/srt/layers/layernorm.py core-logic

唯一修改的文件,包含 GemmaRMSNorm 类的 `__init__` 和 `_weight_loader` 方法,修复了 buffer 存储不稳定导致 CUDA Graph 失效的问题。

# python/sglang/srt/layers/layernorm.py
# GemmaRMSNorm 类核心变更:保持 buffer 内存地址固定,以兼容 CUDA Graphclass GemmaRMSNorm(MultiPlatformOp):
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps
        # 改动 1:用 torch.ones_like 创建持久 buffer
        # 原代码 self.register_buffer("gemma_weight", self.weight.data + 1.0, ...)
        # 会分配新内存,导致后续赋值时丢弃原 buffer 的固定地址
        self.register_buffer(
            "gemma_weight", torch.ones_like(self.weight), persistent=False
        )
        self.weight.weight_loader = self._weight_loader
​
    def _weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)
        # 改动 2:使用 torch.add 原地更新,确保 CUDA Graph 捕获的 buffer 引用指向同一内存
        # 原代码 self.gemma_weight = param.data + 1.0 会破坏存储稳定性
        torch.add(param.data, 1.0, out=self.gemma_weight)
​
    # _forward_impl 使用 self.weight.data 而非 self.gemma_weight,因此无需修改
    def _forward_impl(self, x, residual=None, post_residual_addition=None):
        # ... ( 保持不变 )
        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  1. CUDA Graph 兼容性:修改确保 buffer 存储固定,与 CUDA Graph 捕获兼容,风险已消除。
  2. 数值正确性torch.add 原地操作与 new_tensor = a + 1.0 结果一致,无精度差异。
  3. 回归风险:仅改动 GemmaRMSNorm 内部实现,不涉及其他模块,回归概率低。但缺少对应测试用例,未来修改可能覆盖不到。

影响范围集中:仅修复使用 GemmaRMSNorm 且权重可变的模型(如 Qwen3.5、Gemma 系列)。用户无需感知代码变更,但推理正确性得到保障。CUDA Graph 在权重更新后能正确捕获新权重,提升模型微调/持续学习场景的稳定性。

缺少测试覆盖 CUDA Graph 兼容性问题

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论