Prhub

#41300 [DeepSeek] Use torch.mm for bf16xbf16->fp32 gemm

原始 PR 作者 WoosukKwon 合并时间 2026-05-01 07:28 文件变更 8 提交数 2 评论 3 代码增减 +9 / -91

执行摘要

删除自定义 bf16→fp32 GEMM,改用 torch.mm

发现 torch.mm 原生支持 bf16×bf16→fp32 矩阵乘法(底层调用 nvjet_sm100_tss_16x64_64x16_4x1_v_bz_TNN),无需再通过自定义 cuBLAS 包装实现,从而简化代码并减少维护成本。PR body 中通过脚本验证了该行为。

该 PR 是清理自定义算子的好示例,展示了如何利用 PyTorch 原生功能替代手写 CUDA 扩展。对于希望减少自定义代码依赖的开发者有参考价值。建议验证环境中的 PyTorch 版本是否支持 torch.mm(..., out_dtype=...)。整体风险可控,可合入。

讨论亮点

gemini-code-assist[bot] 对 torch.mmout_dtype 参数提出了兼容性质疑,指出该参数并非标准 PyTorch 2.5 及之前版本的公共 API,可能导致 TypeError;同时提醒 torch.mm 仅支持 2D 输入,若 hidden_states 为 3D 将引发运行时错误。然而,该 PR 最终被 tlrmchlsmth 批准,表明团队内部已确认所需 PyTorch 版本的兼容性(可能依赖 PyTorch 2.6+),或该路径中的输入形状已保证为 2D。

实现拆解

  1. 删除自定义 Python 算子:在 vllm/_custom_ops.py 中移除 router_gemm_bf16_fp32 函数及其 fake 注册 router_gemm_bf16_fp32_fake
  2. 删除中间包装函数:在 vllm/model_executor/layers/utils.py 中删除 cublas_gemm_bf16_bf16_fp32,该函数仅委托给 ops.router_gemm_bf16_fp32
  3. 替换调用点:在 vllm/model_executor/layers/deepseek_v4_attention.pyattn_gemm_parallel_execute 中,将 cublas_gemm_bf16_bf16_fp32 替换为 torch.mm 并指定 out_dtype=torch.float32;在 vllm/model_executor/layers/fused_moe/router/gate_linear.pyforward Tier 2 中,将 ops.router_gemm_bf16_fp32 替换为 torch.mm
  4. 删除 C++ 实现:移除 csrc/moe/router_gemm.cu(自定义 cuBLAS GEMM 实现)、csrc/moe/moe_ops.h 中的声明以及 csrc/moe/torch_bindings.cpp 中的绑定,并从 CMakeLists.txt 中移除对应的编译源。
  5. 无测试配套变更:本次改动未添加或修改测试,依赖现有测试覆盖。
文件 模块 状态 重要度
vllm/_custom_ops.py 自定义算子 modified 7.05
vllm/model_executor/layers/utils.py 工具层 modified 6.36
vllm/model_executor/layers/deepseek_v4_attention.py 注意力层 modified 6.3
csrc/moe/router_gemm.cu 路由内核 removed 5.21
vllm/model_executor/layers/fused_moe/router/gate_linear.py 门控线性层 modified 5.17

关键符号

router_gemm_bf16_fp32 router_gemm_bf16_fp32_fake cublas_gemm_bf16_bf16_fp32 attn_gemm_parallel_execute GateLinear.forward

关键源码片段

vllm/model_executor/layers/deepseek_v4_attention.py data-contract

两个关键调用点从 `cublas_gemm_bf16_bf16_fp32` 改为 `torch.mm`,影响注意力压缩器和索引器路径。

def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
    assert self.aux_stream_list is not None
    assert len(self.aux_stream_list) >= 3
​
    # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs
    # on aux streams 0..2 when their owning module exists.
    aux_fns: list[Callable[[], Any] | None] = [None, None, None]
​
    if self.compressor is not None:
        compressor = self.compressor
​
        def compressor_kv_score() -> torch.Tensor:
            # 使用 torch.mm 实现 bf16×bf16→fp32 GEMM
            return torch.mm(
                hidden_states,
                compressor.fused_wkv_wgate.weight.T,
                out_dtype=torch.float32,
            )
​
        aux_fns[0] = compressor_kv_score
​
    if self.indexer is not None:
        indexer = self.indexer
​
        def indexer_weights_proj() -> torch.Tensor:
            weights, _ = indexer.weights_proj(hidden_states)
            return weights
​
        def indexer_compressor_kv_score() -> torch.Tensor:
            return torch.mm(
                hidden_states,
                indexer.compressor.fused_wkv_wgate.weight.T,
                out_dtype=torch.float32,
            )
​
        aux_fns[1] = indexer_weights_proj
        aux_fns[2] = indexer_compressor_kv_score
​
    def fused_wqa_wkv() -> torch.Tensor:
        qr_kv, _ = self.fused_wqa_wkv(hidden_states)
        return qr_kv
​
    qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel(
        fused_wqa_wkv,
        aux_fns,
        self.ln_events[0],
        self.ln_events[1:4],
        self.aux_stream_list[:3],
    )
​
    return qr_kv, kv_score, indexer_kv_score, indexer_weights
vllm/model_executor/layers/fused_moe/router/gate_linear.py data-contract

在 MoE 路由前向传播的 Tier 2 中替换了 GEMM 调用,影响路由逻辑。

def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
    import vllm._custom_ops as ops
​
    # Tier 1: DSV3 specialized kernel
    if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
        output = ops.dsv3_router_gemm(
            hidden_states=x,
            router_weight=self.weight,
            output_dtype=self.out_dtype,
        )
        return output, None
​
    # Tier 2: cuBLAS bf16→fp32 (now using native torch.mm with out_dtype)
    if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
        output = torch.mm(x, self.weight.T, out_dtype=torch.float32)
        return output, None
​
    # Tier 3: F.linear (ReplicatedLinear)
    if self.out_dtype is not None and x.dtype != self.weight.dtype:
        x = x.to(self.weight.dtype)
    output, output_bias = super().forward(x)
    if self.out_dtype is not None and output.dtype != self.out_dtype:
        output = output.to(self.out_dtype)
    return output, output_bias

评论区精华

torch.mm out_dtype 参数兼容性及输入维度要求 正确性

gemini-code-assist[bot] 指出 `out_dtype` 不是标准 PyTorch 2.5 及之前版本的 API,可能导致 `TypeError`;且 `torch.mm` 仅支持 2D 输入,若输入为 3D 会出错。

结论:PR 被 tlrmchlsmth 批准,暗示团队已确认环境满足 PyTorch 版本要求且输入形状为 2D,或后续将处理兼容性。 · ACKNOWLEDGED

风险与影响

  1. PyTorch 版本兼容性torch.mmout_dtype 参数在 PyTorch 2.5 及以下版本不存在,使用旧版 PyTorch 的用户会遇到 TypeError。vLLM 通常依赖较新版本的 PyTorch(≥2.6),但若实际环境中版本不满足,需添加 fallback 或版本检查。
  2. 输入形状假设torch.mm 严格要求 2D 输入,若 hidden_states 在推理时可能以 3D shape(如 [batch, seq, hidden])传入,则会失败。当前调用点处于合并了 batch 和 seq 维度的上下文中(注意力层输入已铺平),风险较低,但仍需确认。
  3. 回归风险:删除自定义 cuBLAS 包装可能影响性能或精度,但 PyTorch 原生实现应同样基于 cuBLAS,且 out_dtype 保证了 fp32 精度,回归可能性较小。

对用户:仅影响 DeepSeek V4/V3 模型的推理路径,需使用支持 out_dtype 的 PyTorch 版本。对开发团队:消除了一个自定义 cuBLAS 包装,降低了 C++ 代码的编译和维护负担。对系统:编译时间略有减少,二进制体积缩小。

PyTorch out_dtype 兼容性 输入 2D 维度假设 无测试变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论