Prhub

#41664 [MXFP4] Support for linear layers + compressed-tensors integration

原始 PR 作者 dsikka 合并时间 2026-05-12 19:49 文件变更 11 提交数 11 评论 8 代码增减 +358 / -26

执行摘要

MXFP4 W4A4 线性层支持,集成 FlashInfer/Marlin 内核

为支持MXFP4量化,需要统一的内核抽象层,使不同硬件后端(FlashInfer CUTLASS on Blackwell和Marlin on其他平台)通过相同接口工作;同时将compressed-tensors方案从W4A16更新为W4A4以反映激活量化能力。

值得精读此PR。重点可关注MxFp4LinearKernel抽象类设计和init_mxfp4_linear_kernel工厂函数的多后端选择模式,以及如何通过环境变量VLLM_MXFP4_USE_MARLIN覆盖内核选择。compressed-tensors方案的重构方式(从直接调用Marlin到委托内核)也为其他量化格式统一提供了参考。此外,swizzle reshape的讨论展示了GPU编程中数据布局对齐的常见陷阱。

讨论亮点

主要讨论集中在FlashInfer内核中权重尺度swizzle后的reshape尺寸问题:

  • @yewentao256指出swizzle_mxfp4_scales内部已padding,但外部仍使用原始N进行reshape,可能导致形状不匹配和运行时错误,建议使用padded_N。
  • 作者@dsikka回复“已解决”,并在后续commit中将reshape改为使用padded_N。
  • @yewentao256还建议使用更精确的导入has_flashinfer_cutedsl替代has_flashinfer,作者采纳。
  • @gemini-code-assist[bot]也识别了相同的reshape风险。整体反馈良好,最终获得approve。

实现拆解

  1. 定义抽象内核基类:在vllm/model_executor/kernels/linear/mxfp4/base.py创建MxFp4LinearLayerConfig数据类和MxFp4LinearKernel抽象类,声明is_supportedcan_implementprocess_weights_after_loadingapply_weights四个抽象接口。
  2. 实现FlashInfer内核:新增vllm/model_executor/kernels/linear/mxfp4/flashinfer.py,实现FlashInferMxFp4LinearKernel。仅当计算能力≥sm_100且安装flashinfer_cutedsl时支持。process_weights_after_loading对权重尺度进行swizzle(填充N至128的倍数);apply_weights先将输入量化为MXFP4,然后调用flashinfer_scaled_fp4_mm进行W4A4 GEMM。
  3. 实现Marlin内核:新增vllm/model_executor/kernels/linear/mxfp4/marlin.py,实现MarlinMxFp4LinearKernel。复用已有的marlin_utils_fp4中的prepare_fp4_layer_for_marlinapply_fp4_marlin_linear,作为非Blackwell平台的回退(W4A16)。
  4. 内核选择工厂:在vllm/model_executor/kernels/linear/__init__.py添加init_mxfp4_linear_kernel函数。根据平台(当前仅CUDA)和环境变量VLLM_MXFP4_USE_MARLIN(强制使用Marlin),遍历_POSSIBLE_MXFP4_KERNELS列表,选择第一个支持的内核实例化。同时扩展_POSSIBLE_MXFP4_KERNELS字典及register_linear_kernel函数以支持mxfp4类型。
  5. 重构compressed-tensors方案:将compressed_tensors_w4a16_mxfp4.py重命名为compressed_tensors_w4a4_mxfp4.py,类名从CompressedTensorsW4A16Mxfp4改为CompressedTensorsW4A4Mxfp4。构造函数中调用init_mxfp4_linear_kernel获取内核实例,然后process_weights_after_loadingapply_weights全部委托给该内核,不再直接与Marlin耦合。
  6. 扩展FlashInfer工具函数:在vllm/utils/flashinfer.py中修改flashinfer_mm_fp4flashinfer_scaled_fp4_mm,增加block_sizeuse_nvfp4参数;新增flashinfer_mxfp4_quantize自定义操作(支持fake tensor注册),用于激活量化。
  7. 测试:在tests/quantization/test_compressed_tensors.py中添加test_compressed_tensors_mxfp4测试,验证MXFP4模型加载和前向。
文件 模块 状态 重要度
vllm/model_executor/kernels/linear/mxfp4/flashinfer.py FlashInfer 后端 added 8.97
vllm/model_executor/kernels/linear/mxfp4/base.py 量化内核 added 8.93
vllm/model_executor/kernels/linear/mxfp4/marlin.py Marlin 后端 added 8.81
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py 压缩张量方案 renamed 7.97
vllm/model_executor/kernels/linear/__init__.py 内核注册 modified 7.71
vllm/utils/flashinfer.py FlashInfer 工具 modified 7.5
tests/quantization/test_compressed_tensors.py 测试 modified 6.06

关键符号

init_mxfp4_linear_kernel FlashInferMxFp4LinearKernel.is_supported FlashInferMxFp4LinearKernel.process_weights_after_loading FlashInferMxFp4LinearKernel.apply_weights MarlinMxFp4LinearKernel.is_supported MarlinMxFp4LinearKernel.apply_weights flashinfer_mxfp4_quantize flashinfer_scaled_fp4_mm CompressedTensorsW4A4Mxfp4.__init__ CompressedTensorsW4A4Mxfp4.process_weights_after_loading

关键源码片段

vllm/model_executor/kernels/linear/mxfp4/flashinfer.py data-contract

新增 FlashInfer MXFP4 内核实现,为 Blackwell 设备提供 W4A4 激活量化的 GEMM 路径,是性能关键路径。

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM projectimport torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import swizzle_mxfp4_scales
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutedsl
from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig_MXFP4_GROUP_SIZE = 32 # 组大小固定为 32class FlashInferMxFp4LinearKernel(MxFp4LinearKernel):
    """MXFP4 W4A4 GEMM via FlashInfer CUTLASS (SM100+)."""
​
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # 需要 Blackwell 架构(sm_100)且安装 flashinfer cutedsl
        if current_platform.has_device_capability(100) and has_flashinfer_cutedsl():
            return True, None
        return False, "FlashInfer + >=sm_100 (Blackwell) required"
​
    @classmethod
    def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        # 当前不检查 config 详细字段,直接表示可以
        return True, None
​
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        N, scale_K = layer.weight_scale.shape
        K = scale_K * _MXFP4_GROUP_SIZE
        # swizzle 并填充 N 至 128 的倍数以满足 CUTLASS tile 要求
        padded_N = ((N + 127) // 128) * 128
        layer.weight_scale = Parameter(
            swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(padded_N, -1),
            requires_grad=False,
        )
​
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from vllm.utils.flashinfer import flashinfer_mxfp4_quantize, flashinfer_scaled_fp4_mm
        weight = layer.weight
        out_shape = x.shape[:-1] + (layer.output_size_per_partition,)
        x_2d = x.reshape(-1, x.shape[-1])
        # 动态量化激活
        x_fp4, x_scale = flashinfer_mxfp4_quantize(x_2d)
        out = flashinfer_scaled_fp4_mm(
            x_fp4, weight, x_scale, layer.weight_scale,
            alpha=None, out_dtype=x.dtype,
            backend="cute-dsl",
            block_size=_MXFP4_GROUP_SIZE,
            use_nvfp4=False, # 使用 mx 格式,而非 nvfp4
        )
        if bias is not None:
            out = out + bias
        return out.view(out_shape)
vllm/model_executor/kernels/linear/mxfp4/base.py data-contract

定义 MXFP4 线性层的抽象基类和配置数据类,是内核后端的统一契约。

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM projectfrom abc import ABC, abstractmethod
from dataclasses import dataclass
import torch@dataclass
class MxFp4LinearLayerConfig:
    """Configuration for an MXFP4 linear layer.
    All MXFP4 layers share the same structure: packed uint8 weights (2 FP4 values per
    byte) and per-block weight scales (group size 32).
    """
    pass # Placeholder for future extensionsclass MxFp4LinearKernel(ABC):
    """Base class for MXFP4 quantized linear kernels.
    Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc).
    The kernel selection mechanism iterates over registered subclasses in
    priority order, calling ``is_supported`` and ``can_implement`` to find the best
    match for the current hardware.
    """
    def __init__(self, config: MxFp4LinearLayerConfig) -> None:
        # 确保子类满足约束
        assert self.can_implement(config)[0]
        assert self.is_supported()[0]
        self.config = config
​
    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return whether this kernel can run on the current platform."""
        ...
​
    @classmethod
    @abstractmethod
    def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Return whether this kernel can handle *config*."""
        ...
​
    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transform weights into the format required by this kernel.
        Called once after checkpoint weights have been loaded onto the
        device. Implementations should repack / swizzle / pad weights
        and scales in-place on *layer*.
        """
        ...
​
    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized GEMM."""
        ...

评论区精华

swizzle_mxfp4_scales reshape 尺寸问题 正确性

@yewentao256 指出 `swizzle_mxfp4_scales` 内部已对 N 填充至 128 的倍数,但 `process_weights_after_loading` 仍使用原始 N 进行 reshape 操作,可能导致形状不匹配。@gemini-code-assist[bot] 也指出同样问题。

结论:作者 @dsikka 在后续 commit 中修改为使用 padded_N 进行 reshape,问题已修复。 · 已解决

使用 has_flashinfer_cutedsl 替代 has_flashinfer 设计

@yewentao256 建议导入 `has_flashinfer_cutedsl` 而不是通用的 `has_flashinfer`,以更精确地表示对 cutedsl 的依赖。

结论:作者采纳建议,并在后续 commit 中修改。 · 已解决

风险与影响

  1. FlashInfer swizzle reshape兼容性FlashInferMxFp4LinearKernel.process_weights_after_loading中,swizzle_mxfp4_scales会填充N至128的倍数,但若进行其他未覆盖的reshape(如layer.weight等),可能仍存在形状不匹配。已在commit中修复,但建议监控用户反馈。
  2. 激活量化额外开销:FlashInfer路径对激活动态量化,增加计算和内存带宽开销,可能在小batch时性能退化。需通过配置或环境变量允许用户选择Marlin回退。
  3. 方案名称变更破坏性CompressedTensorsW4A16Mxfp4重命名为W4A4Mxfp4,旧类名不再导出,可能破坏依赖旧名称的外部代码。建议在文档或changelog中说明迁移路径。
  4. Marlin内核依赖外部模块MarlinMxFp4LinearKernel直接依赖marlin_utils_fp4,该模块可能变化,需确保接口稳定。

影响用户:使用compressed-tensors量化MXFP4模型的用户在Blackwell设备上将获得W4A4推理性能提升(激活量化降低带宽),其他设备兼容W4A16。需注意类名变更。影响系统:新增内核抽象和选择逻辑,增加少量初始化开销但无运行时影响。影响团队:提供了可扩展的内核注册机制,便于未来添加新量化格式后端。影响范围:中等,仅涉及量化模型加载和线性层计算路径,非核心调度或通信路径。

swizzle padding 兼容风险 激活量化性能开销 方案重命名破坏风险

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论