Prhub

#24360 [AMD] Replace naive triton RMSNorm with aiter RMSNorm for diffusion model

原始 PR 作者 yichiche 合并时间 2026-05-08 17:44 文件变更 1 提交数 1 评论 1 代码增减 +50 / -0

执行摘要

AMD 扩散模型 RMSNorm 替换为 aiter 内核,加速 30%

替换基于 Triton 的朴素 RMSNorm 实现为 aiter 的优化 HIP 内核,在 ROCm 扩散模型推理路径上获得约 30% 的 kernel 级加速(430 us → 290 us),尽管该 kernel 仅占去噪总时间的 0.2%,但属于低成本的单体替换改进。

值得合并,改动清晰且风险可控。建议作者后续添加单元测试以覆盖 forward_aiter 的各种输入情况(包括 fp32 回退、残差分支、非连续张量等),确保长期维护性。该 PR 展示了如何在后端利用专用算子库(aiter)替换通用实现,是良好的微优化案例。

讨论亮点

主要讨论来自 gemini-code-assist[bot] 的一条 review comment,建议使用已有的 _get_weight 辅助方法获取权重,而非手动转换 dtype,以保持与 MUSA 等其他后端的实现一致性。该建议已被作者采纳并体现在最终代码中(head_excerpt 中可见)。

实现拆解

  1. 导入与配置调整:在 python/sglang/multimodal_gen/runtime/layers/layernorm.py 中新增 _is_hip 标志和 _use_aiter 控制变量(通过 SGLANG_USE_AITER 环境变量启用),并在 __init__ 中根据 _use_aiter_forward_method 指向新增的 forward_aiter
  2. aiter 函数导入:当 _use_aiter 为 True 时,从 aiter 包导入 rmsnorm2d_fwdrmsnorm2d_fwd_with_add,分别作为 rms_normfused_add_rms_norm 使用。
  3. forward_aiter 方法实现:该方法首先检查输入 dtype(仅支持 fp16/bf16)和 variance_size_override,不满足则回退到 forward_native;随后将任意形状输入 reshape 为 2D((batch*seq_len, hidden_size))以适配 aiter 期望的布局;若 residual 存在则调用 fused_add_rms_norm(带残差融合),否则调用 rms_norm,最后将输出 reshape 回原始形状。
  4. 回退机制:基于 _is_hip_use_aiter 的组合,确保非 HIP 平台或未启用 aiter 时行为与原 forward_hip 一致。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/layers/layernorm.py 归一化层 modified 7.24

关键符号

forward_aiter

关键源码片段

python/sglang/multimodal_gen/runtime/layers/layernorm.py core-logic

核心修改文件,新增 `forward_aiter` 方法,调整 `__init__` 中的方法选择逻辑,导入 aiter 模块,新增环境变量控制开关。

# python/sglang/multimodal_gen/runtime/layers/layernorm.py
# 关键片段:forward_aiter 方法,支持残差融合回退和 FP32 回退def forward_aiter(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """使用 aiter 库的 HIP 内核执行 RMSNorm,
    仅在 FP16/BF16 且无 variance_size_override 时调用,
    否则回退到 forward_native
    """
    # aiter CK kernel 只支持 FP16/BF16(输出 dtype 检查会拒绝 FP32)
    if (x.dtype not in (torch.float16, torch.bfloat16)
            or self.variance_size_override is not None):
        return self.forward_native(x, residual)
​
    # 利用已有 _get_weight 方法确保权重 dtype 与输入一致
    weight = self._get_weight(x.dtype)
​
    # 将任意形状输入 reshape 为 (batch*seq_len, hidden_size) 2D 张量
    shape = x.shape
    x_2d = x.reshape(-1, shape[-1])
    if not x_2d.is_contiguous():
        x_2d = x_2d.contiguous()
​
    if residual is not None:
        residual_shape = residual.shape
        residual_2d = residual.reshape(-1, shape[-1])
        if not residual_2d.is_contiguous():
            residual_2d = residual_2d.contiguous()
        output = torch.empty_like(x_2d)
        residual_out = torch.empty_like(x_2d)
        # fused_add_rms_norm: 同时计算 RMSNorm 和残差加法
        fused_add_rms_norm(output, x_2d, residual_2d, residual_out,
                           weight, self.variance_epsilon)
        return output.view(shape), residual_out.view(residual_shape)
​
    # 无残差分支:直接调用 aiter rmsnorm
    return rms_norm(x_2d, weight, self.variance_epsilon).view(shape)

评论区精华

权重 dtype 转换建议使用 _get_weight 辅助方法 设计

gemini-code-assist[bot] 建议使用已有的 `_get_weight` 方法替代手动 dtype 转换,以保持与其他后端(如 MUSA)的一致性。

结论:作者已采纳建议,最终代码中使用了 `self._get_weight(x.dtype)`。 · 已解决

风险与影响

低风险。变更局限在 forward_aiter 方法内,有明确的回退逻辑(fp32、非 HIP 平台、variance_size_override 非 None 时回退到 forward_native)。主要风险是 aiter 包在某些 ROCm 版本上不可用或 API 不兼容,但通过 SGLANG_USE_AITER 环境变量开关可隔离。未新增测试文件,可能缺少回归覆盖。

影响范围小:仅修改一个源文件 layernorm.py,仅影响 AMD ROCm 平台且启用了 SGLANG_USE_AITER 的扩散模型推理路径。对 CUDA、NPU、MUSA 等其他后端无影响。对去噪阶段整体加速约 0.2%,但对 RMSNorm 单个算子加速明显。

缺少测试覆盖 核心路径变更 环境变量控制

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论