Prhub

#42453 [Feature] Support batch invariant rms norm with residual

原始 PR 作者 yewentao256 合并时间 2026-06-04 03:22 文件变更 3 提交数 2 评论 2 代码增减 +39 / -41

执行摘要

融合 residual 支持到 batch-invariant RMS norm

PR body 指出需要支持批量不变的RMS归一化与residual融合,以使RMSNorm(CustomOp)的代码更加清晰。作者声明'No functional change as we go into the same kernel path',纯粹是代码重构。

值得精读,特别是关于批处理不变性归一化的设计模式。合并函数并支持可选residual的做法简洁清晰,可作为类似重构的参考。

讨论亮点

主要的讨论来自gemini-code-assist[bot]的两条评论:

  • 建议在融合路径中也确保输入tensor是contiguous并reshape到2D,并将weight断言移到函数顶部以适用于两个分支。该建议未被采纳,作者认为当前实现已经正确。
  • 建议将vllm._custom_ops的导入移到模块顶部以避免热路径开销。该建议未被采纳。

实现拆解

  1. 合并函数定义:在vllm/model_executor/layers/batch_invariant.py中,删除原来的rms_norm函数(Triton kernel实现)和简单的包装函数rms_norm_batch_invariant,将两者合并为新的rms_norm_batch_invariant。新函数接受可选参数residual: torch.Tensor | None = None:当residual不为None时,直接调用ops.fused_add_rms_norm进行融合计算并返回(output, residual_out)元组;当residualNone时,执行原有的Triton RMS归一化逻辑。
  2. 更新调用点:在vllm/model_executor/layers/layernorm.pyRMSNorm.forward_cuda中,修改batch invariant分支的条件和参数:移除residual is None的守卫条件,传入residual=residual,并增加断言variance_size_override is None(因为该参数不支持批量不变模式)。这样当VLLM_BATCH_INVARIANT启用时,无论有无residual都会走统一的rms_norm_batch_invariant路径。
  3. 适配测试文件:在tests/v1/determinism/test_rms_norm_batch_invariant.py中,将导入语句从from ... import rms_norm as triton_rms_norm改为from ... import rms_norm_batch_invariant,所有测试用例中triton_rms_norm的调用替换为rms_norm_batch_invariant,保证测试继续有效。
文件 模块 状态 重要度
vllm/model_executor/layers/batch_invariant.py 模型执行器 modified 8.0
vllm/model_executor/layers/layernorm.py 模型执行器 modified 6.85
tests/v1/determinism/test_rms_norm_batch_invariant.py 测试 modified 5.47

关键符号

rms_norm_batch_invariant RMSNorm.forward_cuda

关键源码片段

vllm/model_executor/layers/batch_invariant.py data-contract

核心变更文件:删除 rms_norm 函数,将功能合并到 rms_norm_batch_invariant,新增 residual 参数和融合路径。

# vllm/model_executor/layers/batch_invariant.py
# 重构后的 rms_norm_batch_invariant 函数:
# - 当提供 residual 时,融合加法与归一化
# - 否则执行标准 Triton RMS 归一化def rms_norm_batch_invariant(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Compute RMS normalization using Triton kernel.
    When residual is provided, it fuses the add and norm step.
    """
    if residual is not None:
        # 融合路径:调用自定义 CUDA kernel 原地修改 input 和 residual
        assert input.shape == residual.shape, (
            f"Input shape {input.shape} must match residual shape {residual.shape}"
        )
        import vllm._custom_ops as ops
        ops.fused_add_rms_norm(input, residual, weight, eps)
        return input, residual
​
    # 标准 Triton RMS 归一化路径(无 residual)
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input.shape[-1]}) must match "
        f"weight dimension ({weight.shape[0]})"
    )
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1]).contiguous()
    weight = weight.contiguous()
    n_rows, n_cols = input_2d.shape
    output = torch.empty_like(input_2d)
    BLOCK_SIZE = 1024
    grid = (n_rows,)
    _rms_norm_kernel[grid](
        input_2d, weight, output,
        input_2d.stride(0), output.stride(0),
        n_cols, eps, BLOCK_SIZE=BLOCK_SIZE,
    )
    return output.reshape(original_shape)
vllm/model_executor/layers/layernorm.py core-logic

修改 forward_cuda 中的 batch invariant 分支,移除 residual is None 的限制,新增 variance_size_override 断言,使调用更统一。

# vllm/model_executor/layers/layernorm.py 中的 RMSNorm.forward_cuda 方法
# 重构后的 batch invariant 分支:现在也处理有 residual 的情况def forward_cuda(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    if envs.VLLM_BATCH_INVARIANT:
        # 批量不变模式不支持 variance_size_override
        assert self.variance_size_override is None, (
            "Batch invariance is not supported for variance_size_override"
        )
        # 直接调用统一的 rms_norm_batch_invariant,可以处理有 / 无 residual
        return rms_norm_batch_invariant(
            x,
            self.weight.data,
            self.variance_epsilon,
            residual=residual,
        )
    # 不走批量不变模式时,回退到原生实现
    return self.forward_native(x, residual)

评论区精华

确保融合路径中 tensor contiguous 及断言放置 正确性

gemini-code-assist[bot] 建议在融合路径中也确保 input/residual/weight 是 contiguous 并 reshape 到 2D,并将 weight 断言移至函数顶部。

结论:未采纳。作者认为当前实现正确,未修改。 · 已解决

将 import 移到模块顶层避免热路径开销 性能

gemini-code-assist[bot] 建议将 `import vllm._custom_ops` 移到文件顶部以避免重复查找开销。

结论:未采纳。作者未修改。 · 已解决

风险与影响

风险较低。变更不改变任何现有功能,所有测试通过(13 passed)。但需注意:

  • 若未来其他调用方直接使用rms_norm函数(已删除),会出现导入错误。但历史PR分析未发现此类使用。
  • VLLM_BATCH_INVARIANT为True且residual不为None时,forward_cuda现在走rms_norm_batch_invariant的融合路径,路径行为改变,但功能等价。

影响范围小,仅涉及3个文件。对用户无感知,系统行为完全一致。对团队来说,代码可读性提升,未来维护更容易。

函数名删除可能导致外部导入失败

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论