Prhub

#41986 [Bugfix] Add swiglu limits to deepgemm fp8 methods

原始 PR 作者 zyongye 合并时间 2026-05-15 06:43 文件变更 13 提交数 6 评论 2 代码增减 +85 / -1

执行摘要

修复 DeepGemm FP8 MoE 遗漏 SwiGLU 截断限制

DeepGemmExperts._act_mul_quant 在 FP8 路径中未将模型的 swiglu_limit 传递给 SiLU/Mul 融合核,导致 SwiGLU 输出无界,在子句边界处产生罕见词汇 token(如 id 83480 "Skip")。此补丁修复该遗漏,确保截断限制正确应用于激活值。详见 Issue #41935。

值得精读。展示了从模型量化配置到 Triton kernel 的完整参数传播链路,以及在融合操作中处理精度一致性的最佳实践。测试用例的设计(使用高斯分布确保 clamp 分支被触发)也为类似正确性修复提供了参考。

讨论亮点

在 review 中,gemini-code-assist[bot] 指出当 HAS_CLAMP 为 false 时 mul_in 未转换为 float32,导致与 clamp 路径的精度不一致。作者在后续提交(67ea6ec)中修复:将 mul_in 无条件提升到 fp32 后再做 clamping,但让 silu_out * mul_in 保持在输入精度进行,从而在不偏离 C++ silu_and_mul 参考的前提下统一了精度路径。

实现拆解

  1. Triton kernel 添加 clamp 分支:在 vllm/model_executor/layers/quantization/utils/fp8_utils.py_silu_mul_per_token_group_quant_fp8_colmajor 内核中新增 clamp_limit 参数和 HAS_CLAMP 编译时常量。当 HAS_CLAMP 为真时,在 SiLU 前对 gate 和 up 分别进行单侧和双侧截断(tl.minimum(act_in, clamp_limit)tl.clamp(mul_in, -clamp_limit, clamp_limit)),且全部在 fp32 中计算后再转回输入精度,保证与 C++ 参考行为一致。

  2. 专家类读取并传递 clamp_limit:在 vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.pyDeepGemmExperts.__init__ 中保存 quant_config.gemm1_clamp_limit 为实例属性,并在 _act_mul_quant 中同时传递给 fused_silu_mul_fp8_quant_packed(UE8M0 路径)和 silu_mul_per_token_group_quant_fp8_colmajor(非 UE8M0 路径)。

  3. 配置构建器统一支持:在 vllm/model_executor/layers/fused_moe/config.pyfp8_w8a8_moe_quant_configmxfp4_w4a8_moe_quant_config 函数中增加 gemm1_clamp_limit 可选参数并将其转发给 FusedMoEQuantConfig。类似地,在 vllm/model_executor/layers/fused_moe/oracle/fp8.pymake_fp8_moe_quant_config 中,确保 MXFP8 分支(block_shape == [1, 32])也传递 swiglu_limit。所有其他量化后端(Quark、CompressedTensors、ModelOpt、Online FP8 等)均同步新增参数传递。

  4. 测试覆盖:在 tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py 中添加 reference_with_clamp 参考实现(在 SiLU/Mul 前手动截断输入)和参数化测试 test_silu_mul_fp8_quant_deep_gemm_clamp,验证 clamp_limit 分别为 7.0 和 10.0 时内核输出与参考一致。测试使用高斯分布(* 8.0)确保触发 clamp 分支。

文件 模块 状态 重要度
vllm/model_executor/layers/quantization/utils/fp8_utils.py 量化核 modified 6.58
vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py MoE 专家 modified 5.74
tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py 测试 modified 6.21
vllm/model_executor/layers/fused_moe/config.py 配置层 modified 5.28
vllm/model_executor/layers/fused_moe/oracle/fp8.py 量化路由 modified 5.28
vllm/model_executor/layers/quantization/quark/quark_moe.py 量化后端 modified 4.7

关键符号

DeepGemmExperts.__init__ DeepGemmExperts._act_mul_quant _silu_mul_per_token_group_quant_fp8_colmajor silu_mul_per_token_group_quant_fp8_colmajor reference_with_clamp test_silu_mul_fp8_quant_deep_gemm_clamp

关键源码片段

vllm/model_executor/layers/quantization/utils/fp8_utils.py core-logic

核心修复位置,Triton kernel 新增 clamp_limit 参数和 HAS_CLAMP 分支,实现 SwiGLU 截断

@triton.jit
def _silu_mul_per_token_group_quant_fp8_colmajor(
    y_ptr, y_q_ptr, y_s_ptr,
    M, N,
    y_s_col_stride: tl.int64,
    eps,
    clamp_limit, # 新增:clamp 值(HAS_CLAMP=0 时被忽略)
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,
    use_ue8m0: tl.constexpr,
    HAS_CLAMP: tl.constexpr, # 新增:编译时常量标志
    GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # ... (program id, offsets omitted)
    act_in = tl.load(act_in_ptrs)
    mul_in = tl.load(act_in_ptrs + N_2)
​
    if HAS_CLAMP:
        # 在 fp32 中截断 gate(上界)和 up(对称),然后转回输入精度
        # 确保与 C++ silu_and_mul 的 compute() 模板一致
        act_in = tl.minimum(act_in.to(tl.float32), clamp_limit).to(
            y_ptr.dtype.element_ty
        )
        mul_in = tl.clamp(mul_in.to(tl.float32), -clamp_limit, clamp_limit).to(
            y_ptr.dtype.element_ty
        )
    act_in = act_in.to(tl.float32)
    one_f32 = tl.cast(1, tl.float32)
    silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty)
    y = (silu_out * mul_in).to(tl.float32) # 乘法在输入精度进行
​
​
def silu_mul_per_token_group_quant_fp8_colmajor(
    input: torch.Tensor,
    output: torch.Tensor | None = None,
    use_ue8m0: bool | None = None,
    eps: float = 1e-10,
    clamp_limit: float | None = None, # 新增:传入 None 表示不截断
):
    # ...
    has_clamp = clamp_limit is not None
    grid = (M // BLOCK_M, N_2 // BLOCK_N)
    _silu_mul_per_token_group_quant_fp8_colmajor[grid](
        input, output, output_scales,
        M, N, output_scales.stride(-1),
        eps,
        clamp_limit if has_clamp else 0.0, # 无截断时传 0.0(不会被使用)
        fp8_min, fp8_max, use_ue8m0,
        has_clamp, # 布尔值作为 constexpr 传入
        GROUP_SIZE, BLOCK_M, BLOCK_N,
    )
vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py data-contract

将 gemm1_clamp_limit 从量化配置传递给两个融合路径

class DeepGemmExperts(mk.FusedMoEExpertsModular):
    def __init__(self, moe_config, quant_config):
        # ...
        # 新增:读取量化配置中的 gemm1_clamp_limit,后续传递给融合 kernel
        self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
​
    def _act_mul_quant(self, input, output, activation):
        # ...
        if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
            if activation == MoEActivation.SILU:
                return fused_silu_mul_fp8_quant_packed(
                    input=input, output_q=output, group_size=block_k,
                    clamp_limit=self.gemm1_clamp_limit, # 新增
                )
        # non-UE8M0 SiLU 路径
        if activation == MoEActivation.SILU:
            use_ue8m0 = ...
            return silu_mul_per_token_group_quant_fp8_colmajor(
                input=input, output=output, use_ue8m0=use_ue8m0,
                clamp_limit=self.gemm1_clamp_limit, # 新增
            )
tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py test-coverage

新增 test_silu_mul_fp8_quant_deep_gemm_clamp 测试用例和 reference_with_clamp 参考实现

def reference_with_clamp(
    x: torch.Tensor, use_ue8m0: bool, clamp_limit: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pre-clamp inputs (gate from above, up symmetric) at the input dtype
    to match the C++ compute() template, then run standard reference."""
    N_2 = x.size(1) // 2
    dtype = x.dtype
    gate = x[..., :N_2].to(torch.float32).clamp(max=clamp_limit).to(dtype)
    up = x[..., N_2:].to(torch.float32).clamp(min=-clamp_limit, max=clamp_limit).to(dtype)
    return reference(torch.cat([gate, up], dim=-1), use_ue8m0)
​
​
@pytest.mark.parametrize("clamp_limit", [7.0, 10.0])
@pytest.mark.skipif(current_platform.is_rocm(), reason="ROCm does not support DeepGemm.")
def test_silu_mul_fp8_quant_deep_gemm_clamp(T, N, clamp_limit):
    set_random_seed(42)
    # 使用高斯分布(* 8.0)确保值超出 clamp_limit,触发 clamp 分支
    input = torch.randn((T, N), dtype=torch.bfloat16, device="cuda") * 8.0
    use_ue8m0 = is_deep_gemm_e8m0_used()
    output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
        input, use_ue8m0=use_ue8m0, clamp_limit=clamp_limit
    )
    ref_output, ref_output_scales = reference_with_clamp(
        input, use_ue8m0, clamp_limit
    )
    torch.testing.assert_close(
        output.to(torch.float32), ref_output.to(torch.float32)
    )
    torch.testing.assert_close(output_scales, ref_output_scales)

评论区精华

mul_in 精度一致性 正确性

gemini-code-assist[bot] 指出当 HAS_CLAMP 为 false 时 mul_in 未转换为 float32,与 clamp 路径的精度不一致。

结论:作者在提交 67ea6ec 中修复,将 mul_in 无条件转换为 float32,同时保持 silu*mul 乘法在输入精度以避免偏离 C++ 参考路径。 · 已解决

风险与影响

主要风险是 clamp 分支引入了编译时条件分支,但 HAS_CLAMPtl.constexpr,不会带来运行时开销。精度一致性的调整(mul_in 提升到 fp32 再 clamp,然后转回输入精度)可能轻微改变非 clamp 路径的输出,但测试已经确认与参考实现吻合。该变更仅影响使用 DeepGemm FP8 量化路径且具有 swiglu_limit 的模型(如 DeepSeek-V4-Flash-Base),其他模型无行为变化。配置传播的链路较长,若某个量化后端忘记添加 gemm1_clamp_limit 参数,会导致 clamp 静默失效,但本次已覆盖所有已知后端。

对用户:使用 DeepSeek-V4-Flash-Base 等具有 swiglu_limit 的模型将正确输出,不再产生异常 token。对性能:新增 constexpr 分支不影响运行时效率;精度统一调整仅在非 clamp 路径引入一次 fp32 转换再转回,开销可忽略。对系统:所有 FP8 和 W4A8 量化配置构建器现在都支持 gemm1_clamp_limit,确保了该功能在切换到不同后端时行为一致。

kernel 变更 配置传递 精度对齐

关联 Issue

#41935 [Bug]: DeepGemmExperts FP8 MoE drops swiglu_limit (DeepSeek-V4-Flash-Base produces glitch tokens)

完整报告

参与讨论