Prhub

#40408 [Perf] Batch invariance with Cutlass fp8 support, 28.9% E2E latency improvement

原始 PR 作者 yewentao256 合并时间 2026-05-12 00:20 文件变更 11 提交数 12 评论 5 代码增减 +305 / -3

执行摘要

使用 Cutlass FP8 实现批量不变性,延迟降低 28.9%

避免 FP8 量化/反量化开销,提升 batch invariant 模式下的推理性能。PR body 明确指出“Use cutlass fp8 to avoid quantize/dequantize overhead”。

该 PR 值得精读,尤其关注:1)如何通过固定 CUTLASS 配置实现 batch invariance 并保持正确性;2)FP8 线性层 apply 的分支设计兼顾性能与回退。对使用 FP8 批处理推理的团队有直接影响。

讨论亮点
  • Batch invariant 属性的质疑tlrmchlsmth 指出 CutlassFP8ScaledMMLinearKernel 的内核行为可能随未来调优而失去 batch invariance 属性。作者 yewentao256 回应已附上测试确保任意 M 下输出正确,且 CI 会运行这些单元测试。
  • 注释表述更新ElizaWszola 询问注释中从“prefer DeepGEMM”改为“prefer direct FP8”是否意味着 block 版本也优先 direct FP8。作者解释“direct FP8”包含了 DeepGEMM。
  • CUTLASS dispatch 注释需求tlrmchlsmth 建议在 sm100 和 sm120 dispatch 中添加注释,明确说明配置需要独立于 M 的原因(batch invariance)。作者确认已添加。

实现拆解

  1. 新增 CUTLASS batch invariant 内核:在 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/ 下的 scaled_mm_sm{90,100,120}_fp8_dispatch.cuhscaled_mm_c2x_sm89_fp8_dispatch.cuh 中新增 cutlass_gemm_*_batch_invariant_dispatchcutlass_scaled_mm_*_batch_invariant_epilogue 函数。这些函数使用固定的 CUTLASS 配置(M=64),不随实际 M 变化,从而保证输出与 batch size 无关。
  2. 修改 Python 量化层:在 vllm/model_executor/layers/quantization/fp8.pyonline/fp8.pyapply 方法中添加分支——当 VLLM_BATCH_INVARIANT 启用且 self.fp8_linearCutlassFP8ScaledMMLinearKernel 实例时,直接调用 self.fp8_linear.apply_weights(layer, x, bias),跳过原有的 BF16 反量化路径。注释也相应更新。
  3. 更新 .cu 文件绑定:在 scaled_mm_sm{90,100,120}_fp8.cuscaled_mm_c2x.cu 中导入并调用新的 batch_invariant_epilogue 函数,使其对 Python 可见。
  4. 新增测试:添加 tests/v1/determinism/test_cutlass_batch_invariance.py,使用 TestFP8Layer 强制 CutlassFP8ScaledMMLinearKernel,验证在不同 batch size 和 weight shape 下,特定 token 的输出严格与 batch size 无关(assert_close 要求 rtol=0, atol=0)。
文件 模块 状态 重要度
tests/v1/determinism/test_cutlass_batch_invariance.py 批量测试 added 6.85
vllm/model_executor/layers/quantization/fp8.py 量化层 modified 6.27
vllm/model_executor/layers/quantization/online/fp8.py 量化层 modified 5.94
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh CUTLASS 内核 modified 4.47
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh CUTLASS 内核 modified 4.46
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh CUTLASS 内核 modified 4.34
csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh CUTLASS 内核 modified 4.3
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu CUTLASS 绑定 modified 3.78
csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu CUTLASS 绑定 modified 3.78
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu CUTLASS 绑定 modified 3.5
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu CUTLASS 绑定 modified 3.5

关键符号

Fp8LinearMethod.apply _Fp8OnlineLinearBase.apply cutlass_gemm_sm90_fp8_batch_invariant_dispatch cutlass_scaled_mm_sm90_fp8_batch_invariant_epilogue test_cutlass_fp8_batch_invariant_fixed_config

关键源码片段

tests/v1/determinism/test_cutlass_batch_invariance.py test-coverage

新增测试文件,验证 Cutlass FP8 内核在不同 batch size 和 weight shape 下的 batch invariance 性质,确保功能正确性。

# tests/v1/determinism/test_cutlass_batch_invariance.pyimport pytest
import torch
import vllm.envs as envs
from tests.utils import TestFP8Layer, requires_fp8
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import CutlassFP8ScaledMMLinearKernel
from vllm.model_executor.layers.quantization.utils.quant_utils import kFp8DynamicTokenSym, kFp8StaticTensorSym
from vllm.platforms import current_platformpytest.importorskip("torch.cuda")@pytest.fixture(autouse=True)
def setup_cuda():
    if not current_platform.is_cuda():
        pytest.skip("CUTLASS FP8 kernels require CUDA.")
    torch.set_default_device("cuda")@requires_fp8
@pytest.mark.parametrize("weight_shape", [(1024, 2048), (4608, 4096)])
@pytest.mark.parametrize("batch_size", [1, 16, 17, 32, 64, 65, 256, 257])
@torch.inference_mode()
def test_cutlass_fp8_batch_invariant_fixed_config(
    weight_shape: tuple[int, int],
    batch_size: int,
    default_vllm_config,
    monkeypatch: pytest.MonkeyPatch,
):
    # 启用 batch invariant 环境变量
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", True)
​
    torch.manual_seed(0)
    # 构造 FP8 层,强制使用 CutlassFP8ScaledMMLinearKernel
    layer = TestFP8Layer(
        weight_shape=weight_shape,
        activation_quant_key=kFp8DynamicTokenSym,
        weight_quant_key=kFp8StaticTensorSym,
        input_dtype=torch.bfloat16,
        out_dtype=torch.bfloat16,
        device=torch.device("cuda"),
        force_kernel=CutlassFP8ScaledMMLinearKernel,
    )
    assert isinstance(layer.kernel, CutlassFP8ScaledMMLinearKernel)
​
    in_features = weight_shape[1]
    # 创建一个 needle token,作为检测 batch invariant 的锚点
    needle = torch.randn((1, in_features), device="cuda", dtype=torch.bfloat16)
    baseline = layer(needle)[0] # 单个 token 的输出作为基准
​
    # 创建 filler 用来组装不同 batch size
    filler = torch.randn(
        (max(batch_size - 1, 0), in_features), device="cuda", dtype=torch.bfloat16
    )
​
    # 将 needle 放在 batch 的最前面和最后面
    front_batch = torch.cat([needle, filler], dim=0)
    back_batch = torch.cat([filler, needle], dim=0)
​
    front_output = layer(front_batch)[0]
    back_output = layer(back_batch)[-1]
​
    # 严格校验:无论 batch size 和 needle 位置,输出与 baseline 一致
    torch.testing.assert_close(front_output, baseline, rtol=0, atol=0)
    torch.testing.assert_close(back_output, baseline, rtol=0, atol=0)
vllm/model_executor/layers/quantization/fp8.py data-contract

核心量化层文件,修改了 `apply` 方法以支持 Cutlass FP8 batch invariant 路径,并更新了导入。

# vllm/model_executor/layers/quantization/fp8.py (partial)import torch
import vllm.envs as envs
from vllm.model_executor.kernels.linear.scaled_mm import (
    CutlassFP8ScaledMMLinearKernel,
    MarlinFP8ScaledMMLinearKernel,
)# ... ( 类定义 )def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # 当启用 VLLM_BATCH_INVARIANT 时,优先使用直接 FP8 路径
    # 如果底层内核是 CutlassFP8ScaledMMLinearKernel 且非 block 量化,
    # 则直接调用 apply_weights,避免 BF16 反量化开销
    if envs.VLLM_BATCH_INVARIANT:
        if self.block_quant:
            assert self.weight_block_size is not None
            return self.fp8_linear.apply_weights(layer, x, bias)
        else:
            # 新分支:直接使用 Cutlass FP8 计算
            if isinstance(self.fp8_linear, CutlassFP8ScaledMMLinearKernel):
                return self.fp8_linear.apply_weights(layer, x, bias)
​
            # 反量化回 BF16 并执行 GEMM (fallback)
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # 多 scale 处理 ( 如 QKV 融合 )
                if weight_scale.dim() == 1 and weight_scale.shape[0] == weight_fp8.shape[0]:
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    weight_bf16 = weight_fp8 * weight_scale
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)
​
    # 非 batch invariant 模式:使用 Marlin 或默认的 FP8 scaled GEMM
    if self.use_marlin:
        return self.fp8_linear.apply_weights(layer, x, bias)
    return self.fp8_linear.apply_weights(layer, x, bias)
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh core-logic

新增 SM90 专用的 batch invariant dispatch 和 epilogue 函数,是 CUTLASS 内核的核心变更。

// csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
// (partial)template <typename InType, typename OutType, bool EnableBias,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_batch_invariant_dispatch(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
  // 该 dispatch 使用固定 CUTLASS 配置,不依赖于 M(batch size)
  // 确保 batch invariance:输出与 M 无关
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  // 检查张量类型为 FP8 e4m3
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn);
  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn);  // 根据 N 维大小选择不同的 CUTLASS 配置
  using Cutlass3xGemmM64_N1280 = typename sm90_fp8_config_M64_N1280<InType, OutType, EnableBias>::Cutlass3xGemm;
  using Cutlass3xGemmM64_N8192 = typename sm90_fp8_config_M64_N8192<InType, OutType, EnableBias>::Cutlass3xGemm;  uint32_t const n = b.size(1); // 输出特征维度
  if (n <= 1280) {
    return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM64_N1280>(
        out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
  }
  return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM64_N8192>(
      out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
}template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_batch_invariant_epilogue(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, EpilogueArgs&&... epilogue_args) {
  // 检查输入类型为 FP8
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn);
  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn);  // 根据输出数据类型选择 bf16 或 half 的 dispatch
  if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
    return cutlass_gemm_sm90_fp8_batch_invariant_dispatch<cutlass::float_e4m3_t, cutlass::bfloat16_t, EnableBias>(
        out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
    return cutlass_gemm_sm90_fp8_batch_invariant_dispatch<cutlass::float_e4m3_t, cutlass::half_t, EnableBias>(
        out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

评论区精华

Batch invariant 属性的质疑与确认 设计

tlrmchlsmth 表达对 CutlassFP8ScaledMMLinearKernel batch invariant 属性的担忧,认为若未来调优可能不再保持。

结论:yewentao256 回应已附上测试确保任意 M 下正确,且 CI 会运行这些单元测试,属性暂时可靠。 · 已解决

注释中 DeepGEMM 的表述更新 style

ElizaWszola 询问为何将注释从 'prefer DeepGEMM FP8 path' 改为 'prefer direct FP8 path',是否 block 版本也优先 direct FP8。

结论:yewentao256 解释 'direct FP8' 包含 DeepGEMM,注释已准确更新。 · 已解决

CUTLASS dispatch 注释需求 documentation

tlrmchlsmth 建议在 sm100 和 sm120 dispatch 函数中添加注释,说明配置需要独立于 M 的原因(batch invariance)。

结论:yewentao256 确认已按要求添加注释。 · 已解决

风险与影响

  • 批次无关性假设风险:当前实现假设固定配置(M=64)在所有情况下都保证输出独立于 M,但若未来对 CUTLASS 内核进行算术优化改变计算结果顺序,可能打破该假设。测试覆盖的 shape 有限(仅两种 weight_shape 和六个 batch size),可能遗漏边界情况。
  • 性能风险:固定 M 配置在部分大 batch 场景下可能非最优,但 batch invariance 要求必须固定,此权衡可接受。
  • 维护同步成本:新增的 batch_invariant_dispatch 系列函数与原始 dispatch 重复较多代码,后续主线内核升级时需要同步更新,增加维护负担。
  • 平台覆盖:SM89/90/100/120 均得到支持,但缺少对其他 GPU 架构(如 SM80)的 fallback 测试,若在未支持的架构上误用可能导致运行时错误。
  • 用户影响:启用 VLLM_BATCH_INVARIANT=1 后,使用 FP8 量化且支持 Cutlass 的模型获得显著性能提升(延迟降低约 29%),且输出不再依赖 batch size,简化了批处理调优。若使用 Marlin 或其他量化 kernel 则行为不变。
  • 系统影响:新增约 300 行 C++ 和 Python 代码,主要影响量化层和 CUTLASS 内核调度。编译时间略有增加。
  • 团队影响:需要维护 batch invariant 内核配置与主内核同步,确保两者行为一致。
批次无关性假设 CUTLASS 配置覆盖 测试 shape 有限

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论