Prhub

#43332 [MoE/b12x] Accept W4A16 (kNvfp4Static, None) in FlashInferB12xExperts supports check

原始 PR 作者 ECMGit 合并时间 2026-06-03 06:20 文件变更 1 提交数 4 评论 11 代码增减 +27 / -4

执行摘要

b12x MoE 后端支持 W4A16 NVFP4 检查点

FlashInferB12xExperts._supports_quant_scheme 原本要求 activation_key 必须为 kNvfp4Dynamic,导致每个 W4A16 NVFP4 检查点(如 nvidia/Qwen3.6-35B-A3B-2.06GB-per-token)都被 dispatcher 拒绝,被迫使用 Marlin。PR body 指出 b12x 内核已经在内部处理 BF16→FP4 激活量化,因此 W4A16 检查点在运行时是兼容的,但元数据门控过于严格。

此 PR 值得精读,因为它展示了一个精心设计的元数据兼容性修复,同时也体现了在热路径中避免动态分配的良好实践。

讨论亮点

gemini-code-assist[bot] 指出虽然 dispatcher 现在可以选中 b12x 后端,但 apply 方法中 assert self.a2_gscale is not None 会在 W4A16 检查点上触发 AssertionError。ECMGit 回复说后续提交已修复。vadiklyutiy 建议将 fc2_input_scale 分配移到 apply 热路径之外,避免运行时分配;ECMGit 接受并改为在 process_weights_after_loading 中提前缓存所有 scale 张量。

实现拆解

  1. 修改 _supports_quant_scheme(文件 flashinfer_b12x_moe.py 第 131-134 行):将原来的严格相等检查 (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic) 改为包含 (kNvfp4Static, None) 的成员检查,以允许 W4A16 NVFP4 检查点通过。
  2. 扩展 process_weights_after_loading(文件 flashinfer_b12x_moe.py 第 58-127 行):在权重加载后处理中,如果 self.a2_gscale 为 None(W4A16 检查点的特征),则创建一个形状为 (num_local_experts,)、值为 1.0 的 fc2_input_scale 张量,并缓存到 self._fc2_input_scale;否则复用现有的 a2_gscale(已置 1)。这样 apply() 在热路径中可以直接使用缓存张量,避免每次推理时检查 a2_gscale 是否为 None。
  3. 更新 apply 方法(文件 flashinfer_b12x_moe.py):将断言从 assert self.a2_gscale is not None 改为 assert self._fc2_input_scale is not None,并将传递给内核的 fc2_input_scale 从 self.a2_gscale 替换为 self._fc2_input_scale。
  4. 添加实例属性 _fc2_input_scaleinit 中):类型为 torch.Tensor | None,初始化 None,确保 apply 中的断言能正确检查。
文件 模块 状态 重要度
vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py MoE 专家层 modified 7.07

关键符号

FlashInferB12xExperts.__init__ FlashInferB12xExperts.process_weights_after_loading FlashInferB12xExperts._supports_quant_scheme FlashInferB12xExperts.apply

关键源码片段

vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py data-contract

核心变更文件,修改了 _supports_quant_scheme、__init__、process_weights_after_loading 和 apply 方法,实现 W4A16 检查点的兼容与 scale 张量的提前分配。

# 文件 : vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py
# 上下文 : 在 __init__ 中添加 _fc2_input_scale 属性,用于缓存 fc2 输入 scale
class FlashInferB12xExperts(mk.FusedMoEExpertsModular):
    # ...
    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config=moe_config, quant_config=quant_config)
        # ...
        # FC2 input scale tensor bound in process_weights_after_loading: the
        # calibrated (now-zeroed) a2_gscale for static-quant checkpoints, or
        # a synthesized uniform-1.0 tensor for W4A16 checkpoints that lack
        # one. Holding it on the instance keeps apply() alloc-free.
        self._fc2_input_scale: torch.Tensor | None = None
​
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # ... (normalize block scales) ...
        if self.a2_gscale is not None:
            self.a2_gscale.fill_(1.0)
            self._fc2_input_scale = self.a2_gscale
        else:
            # W4A16 NVFP4 checkpoints have no calibrated a2_gscale; b12x
            # performs dynamic per-block FC2-input quantization, so a uniform
            # 1.0 scale per expert is equivalent to the bake-in above for
            # static-quant checkpoints. Allocate once here so apply() stays
            # alloc-free.
            self._fc2_input_scale = torch.ones(
                self.num_local_experts,
                device=layer.w13_weight.device,
                dtype=torch.float32,
            )
​
    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        # b12x performs in-kernel BF16->FP4 activation quant, so W4A16
        # NVFP4 checkpoints (activation_key=None, e.g. mixed-precision
        # compressed-tensors layouts) are runtime-compatible.
        return (weight_key, activation_key) in (
            (kNvfp4Static, kNvfp4Dynamic),
            (kNvfp4Static, None),
        )
​
    def apply(self, ...):
        # ...
        assert self._fc2_input_scale is not None, (
            "_fc2_input_scale must be set by process_weights_after_loading"
        )
        # ... pass self._fc2_input_scale to the kernel

评论区精华

apply 中 a2_gscale 断言在 W4A16 上会失败 正确性

gemini-code-assist[bot] 指出即使 dispatcher 现在允许 b12x 用于 W4A16,但 apply 中 assert self.a2_gscale is not None 仍会引发错误。

结论:ECMGit 在后续提交中修复:将 a2_gscale 引用替换为 _fc2_input_scale,并在 process_weights_after_loading 中为 W4A16 预分配全 1 张量。 · 已解决

fc2_input_scale 分配应移出热路径 性能

vadiklyutiy 建议将 scale 张量的分配从 apply(热路径)移到早期阶段,以避免每次调用时分配内存。

结论:ECMGit 将分配移至 process_weights_after_loading 中,仅在权重加载后执行一次,apply 中直接使用缓存张量。 · 已解决

风险与影响

风险极低。此 PR 仅修改元数据检查逻辑和 scale 张量的来源,不涉及任何内核代码或算子重写。存在回归的唯一场景是如果某个 W4A16 检查点实际上确实需要 activation scale 为非 None 值,但 b12x 内核动态量化假设已覆盖此情况。测试计划在 DGX Spark 上进行了端到端验证,吞吐量持平 Marlin。

对用户:使用 W4A16 NVFP4 检查点的用户现在可以在 SM12x 设备上使用 b12x MoE 后端,获得接近 Marlin 的吞吐性能(实测 91.00 tok/s vs Marlin 92.26 tok/s)。对系统:无性能退化(scale 张量提前分配),对 Marlin 回退路径无影响。对团队:补全了 W4A16 与 b12x 的集成,与 PR #42566 形成互补。

最小化变更 无内核修改 跨 PR 依赖

关联 Issue

#42566 [Quantization][ModelOpt] W4A16 NVFP4 fused MoE + mixed-precision dispatch

完整报告

参与讨论