Prhub

#38093 [Bugfix] Fix scaled_mm output narrowing for 3D input tensors

原始 PR 作者 nemanjaudovic 合并时间 2026-04-20 16:58 文件变更 1 提交数 25 评论 5 代码增减 +15 / -4

执行摘要

修复 FP8 scaled_mm 内核中 3D 输入张量输出缩小错误,避免引擎初始化崩溃。

在初始化Phi-3.5-vision-instruct-FP8-dynamic模型时,引擎崩溃并报错'shape mismatch',原因是输出缩小步骤使用了output_shape[0](仅批次维度)而不是完整令牌数(批次乘以序列长度)。PR body中描述了具体崩溃场景和修复需求,即修复scaled_mm输出缩小对3D输入张量的不正确处理。

建议工程师精读此PR,关注_get_num_tokens辅助函数的设计,它展示了处理张量形状展平时的通用模式。此外,review讨论中的DRY原则实践值得借鉴,有助于提高代码质量。

讨论亮点

Review中,gemini-code-assist[bot]指出num_tokens计算逻辑在三个类中重复,建议提取为模块级辅助函数以遵循DRY原则。nemanjaudovic响应并添加了_get_num_tokens函数,tjtanaa随后批准。讨论焦点是代码可维护性改进,无未解决疑虑。

实现拆解

  1. 导入模块并添加辅助函数:在vllm/model_executor/kernels/linear/scaled_mm/pytorch.py文件中添加import math,并定义_get_num_tokens(output_shape: list) -> int函数,计算除最后一个维度外的所有维度乘积。
  2. 更新内核变体中的apply_scaled_mm方法:在TorchFP8ScaledMMLinearKernelRowWiseTorchFP8ScaledMMLinearKernelChannelWiseTorchFP8ScaledMMLinearKernel类中,用num_tokens = _get_num_tokens(output_shape)替换硬编码的output_shape[0],确保输出缩小步骤正确处理任意维度输入。
  3. 测试验证:PR body中提及通过运行初始化脚本和预提交钩子验证修复。虽然没有新增测试文件,但修复基于实际崩溃场景验证,并确保前向传播正确。
文件 模块 状态 重要度
vllm/model_executor/kernels/linear/scaled_mm/pytorch.py 线性内核 modified 6.73

关键符号

_get_num_tokens TorchFP8ScaledMMLinearKernel.apply_scaled_mm RowWiseTorchFP8ScaledMMLinearKernel.apply_scaled_mm ChannelWiseTorchFP8ScaledMMLinearKernel.apply_scaled_mm

关键源码片段

vllm/model_executor/kernels/linear/scaled_mm/pytorch.py core-logic

这是唯一修改的文件,包含所有 FP8 scaled_mm torch 内核的实现,修复了输出形状计算的核心逻辑,直接影响模型初始化和前向传播。

import math # 新增导入,用于计算乘积def _get_num_tokens(output_shape: list) -> int:
    """
    计算输出形状中的令牌数。
    torch._scaled_mm 处理 2D 张量,输入张量如果是 3D 会被展平。
    如果 output_shape 是 3D,令牌数是除最后一个维度(隐藏维度)外所有维度的乘积。
    """
    return math.prod(output_shape[:-1])# 示例:在 TorchFP8ScaledMMLinearKernel 的 apply_scaled_mm 方法中使用
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        output = torch._scaled_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        )
        # 处理 torch 版本差异,确保输出是单个张量
        if type(output) is tuple and len(output) == 2:
            output = output[0]
​
        num_tokens = _get_num_tokens(output_shape) # 使用辅助函数获取正确令牌数
        return torch.narrow(output, 0, 0, num_tokens).view(*output_shape) # 正确缩小并重塑形状

评论区精华

代码重复与辅助函数提取 设计

gemini-code-assist[bot] 指出 num_tokens 计算逻辑在三个地方重复,建议提取为模块级辅助函数以遵循 DRY 原则。

结论:nemanjaudovic 添加了 _get_num_tokens 函数,解决了重复问题,代码更易维护。 · 已解决

风险与影响

风险较低:变更仅限于单个文件中的形状计算逻辑,影响范围明确。潜在风险是如果output_shape不符合预期(如少于两个维度),_get_num_tokens可能抛出错误,但鉴于调用上下文,这种情况应已由基类处理。无性能或安全风险,兼容性良好。

对用户:修复了使用3D输入张量的FP8量化模型(如Phi-3.5)的初始化崩溃问题,提升稳定性和可用性。对系统:确保FP8线性内核在多种输入形状下正常工作,支持更广泛的模型配置和平台(如ROCm)。对团队:代码更简洁,减少重复逻辑,便于未来维护和扩展。

核心路径变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论