Prhub

#23974 [AMD] Fix Aiter RMSNorm layout handling

原始 PR 作者 hubertlu-tw 合并时间 2026-04-29 10:28 文件变更 1 提交数 2 评论 1 代码增减 +13 / -1

执行摘要

修复 Aiter RMSNorm 在 strided 高维输入下的内存访问越界

Qwen3 等模型在 MI325/gfx942 上通过 Aiter 后端运行时,Q/K 投影切片产生 strided 3D 视图(如 shape=(2048,32,128), stride=(6144,128,1)),直接传入 Aiter 的 RMSNorm 内核导致 GPU 内存访问越界崩溃。PR body 明确描述了根因和复现步骤。

值得合入。PR 定位准确,修复方案最小且安全,对 AMD 用户有实际价值。建议阅读 forward_aiter() 的实现方式,可作为处理不同后端 kernel 约束的参考模式。

讨论亮点

无 review 评论。PR 由 HaiShaw 直接 approve,无明显争议。

实现拆解

  1. 布局检测:在 forward_aiter() 入口处增加 needs_reshape = x.dim() != 2 and residual is None 判断,识别非 2D 且无 residual 的输入。
  2. 布局规范化:若 needs_reshape,则记录原形状 original_shape = x.shape,执行 x.contiguous().reshape(-1, original_shape[-1]) 将输入展平为 2D 连续张量;若为 2D 但不连续,则仅做 x.contiguous()
  3. 计算结果恢复:在无 residual 分支末尾,若 needs_reshape,则将 output reshape 回 original_shape
  4. 保留零拷贝路径:2D 连续输入不走任何拷贝,保持原性能。

变更仅涉及 python/sglang/srt/layers/layernorm.py 中的 forward_aiter() 方法(+13/-1),无其他文件改动。

文件 模块 状态 重要度
python/sglang/srt/layers/layernorm.py 运行层 modified 6.02

关键符号

forward_aiter

关键源码片段

python/sglang/srt/layers/layernorm.py core-logic

实现了 Aiter RMSNorm 布局规范化修复,是本次 PR 唯一修改的文件。

def forward_aiter(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Aiter 的 RMSNorm 内核只支持 2D contiguous 输入。
    # 对于已安全的 2D 连续布局保持零拷贝路径,仅规范化 strided 或高秩视图
    # 如从 packed QKV 投影中切出的 Q/K 切片 (shape=(2048,32,128), stride=(6144,128,1))
    needs_reshape = x.dim() != 2 and residual is None
    if needs_reshape:
        original_shape = x.shape
        # 转为连续后展平为 (batch*heads, dim) 的 2D 张量
        x = x.contiguous().reshape(-1, original_shape[-1])
    elif not x.is_contiguous():
        # 2D 但不连续时仅做连续化(可能由转置等引起)
        x = x.contiguous()
​
    if residual is not None:
        residual_out = torch.empty_like(x)
        output = torch.empty_like(x)
        if post_residual_addition is not None:
            residual = residual + post_residual_addition
        fused_add_rms_norm(
            output,
            x,
            residual,
            residual_out,
            self.weight.data,
            self.variance_epsilon,
        )
        return output, residual_out
​
    # 无 residual 路径:计算 RMSNorm
    output = rms_norm(x, self.weight.data, self.variance_epsilon)
    if needs_reshape:
        # 恢复原始高秩形状,保证后续计算正确
        output = output.reshape(original_shape)
    return output

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  1. 回归风险低:仅修改无 residual 的 Aiter RMSNorm 路径,且逻辑与现有 forward_hip() 中的 contiguous 处理类似。
  2. 性能影响:仅在对非 2D 输入时引入一次拷贝+reshape,该路径在正常模型中不常见(仅 Q/K 切片等场景触发),对整体推理吞吐影响可忽略。
  3. 未覆盖 residual 路径needs_reshape 条件排除了 residual is not None 的情况,若未来有非 2D 输入带 residual 的场景,仍可能崩溃。但当前代码中 residual 路径通常用于 pre-norm 残差连接,输入多为 2D。
  1. 用户影响:修复了 AMD ROCm/Aiter 后端上 Qwen3 等模型的崩溃问题,用户可正常使用 Aiter RMSNorm 加速。
  2. 系统影响:仅影响 forward_aiter() 路径,不影响其他后端(如 HIP、NPU、MUSA)或 RMSNorm 的其他模式。
  3. 团队影响:修复代码简洁,易于维护。PR 建议后续在 Aiter 上游加固 API,但本次作为 caller 侧 guard 已足够。
缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论