执行摘要
- 一句话:将 SiLU+Mul 与 FP8 块量化融合,提升 MiniMax-M2 MoE 性能
- 推荐动作:值得精读,尤其是条件融合的设计模式。虽然 review 中暴露了 block_shape 类型鲁棒性等细节问题,但整体思路清晰。建议后续开发者注意将
self.block_shape 可能为 None 或 tuple 的类型信息明确化,并考虑为 DeepGEMM E8M0 路径添加等效的 fused kernel 或统一量化接口。
功能与动机
原始 TritonExperts.apply 在完成 W1 矩阵乘后,先显式调用 self.activation(包含 Silu+Mul)再调用 moe_kernel_quantize_input 进行 FP8 分块量化,两次 kernel launch 造成额外显存带宽浪费与调度开销。通过直接调用已存在的 ops.silu_and_mul_per_block_quant 融合 CUDA kernel,可以在一次 kernel 内完成激活与量化,同时减少一次中间缓存 intermediate_cache2 的读写。
实现拆解
- 新增导入:引入
vllm.model_executor.layers.quantization.utils.fp8_utils.is_deep_gemm_e8m0_used() 函数,用于判断是否使用 DeepGEMM E8M0 量化模式。
- 核心条件分支:在
TritonExperts.apply() 方法的 w1 矩阵乘处理之后,增加一个 if 条件判断:
activation == MoEActivation.SILU(仅门控 SiLU 适用)
self.quant_config.use_fp8_w8a8(仅 FP8 W8A8 量化)
self.block_shape == [128, 128](仅 128x128 块形状)
lora_context is None(LoRA 场景下需要显式保留中间结果)
not is_deep_gemm_e8m0_used()(DeepGEMM E8M0 模式下此 kernel 不兼容)
- 融合路径:当所有条件满足时,调用
ops.silu_and_mul_per_block_quant() 直接对 intermediate_cache1(W1 输出)进行 SiLU+Mul 激活与 FP8 块量化,返回量化结果和缩放因子。
- 回退路径:否则,执行原来的分离流程:先
self.activation() 得到 intermediate_cache2,再调用 moe_kernel_quantize_input() 量化。
- 后续流程不变:无论走哪条路径,得到的
qintermediate_cache2 和 a2q_scale 都继续传递给下游的 w2 矩阵乘核。
- 测试与配置:本次变更未修改测试文件。实测在 4×H800 和 4×H200 上通过 benchmark 验证性能提升和精度保持(GSM8K 准确率未退化)。
关键文件:
vllm/model_executor/layers/fused_moe/experts/triton_moe.py(模块 MoE专家;类别 source;类型 core-logic;符号 TritonExperts.apply): 唯一变更文件,实现了 TritonFP8MoE 中 SiLU+Mul 与 FP8 分块量化的融合路径。
关键符号:TritonExperts.apply, ops.silu_and_mul_per_block_quant
关键源码片段
vllm/model_executor/layers/fused_moe/experts/triton_moe.py
唯一变更文件,实现了 TritonFP8MoE 中 SiLU+Mul 与 FP8 分块量化的融合路径。
# File: vllm/model_executor/layers/fused_moe/experts/triton_moe.py
# ... 在 apply 方法中,完成 w1 矩阵乘后:
a2q_scale: torch.Tensor | None = None
# 当满足以下条件时,使用 fused kernel 一步完成 SiLU+Mul 与 FP8 分块量化
# - 激活函数为门控 SiLU(即 MoEActivation.SILU)
# - 使用 FP8 W8A8 量化
# - 块形状为 [128, 128](group_size=128)
# - 没有 LoRA 需要保留中间结果
# - 未启用 DeepGEMM E8M0 模式(其量化行为不同)
if (
activation == MoEActivation.SILU
and self.quant_config.use_fp8_w8a8
and self.block_shape == [128, 128]
and lora_context is None
and not is_deep_gemm_e8m0_used()
):
# 调用 fused CUDA custom op: silu_and_mul_per_block_quant
qintermediate_cache2, a2q_scale = ops.silu_and_mul_per_block_quant(
intermediate_cache1.view(-1, N), # 输入为 W1 输出,已 reshape
group_size=128,
quant_dtype=current_platform.fp8_dtype(),
)
else:
# 回退路径:分两步执行
# 1) 执行激活(SiLU+Mul)
self.activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
)
# 2) 对激活结果进行 FP8 分块量化
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2,
a2_scale,
self.quant_dtype,
self.per_act_token_quant,
self.block_shape,
quantization_emulation=self.quantization_emulation,
)
# 后续的 w2 矩阵乘保持不变,接收 qintermediate_cache2 与 a2q_scale
invoke_fused_moe_triton_kernel(
qintermediate_cache2, w2, intermediate_cache3, a2q_scale,
self.w2_scale, topk_weights, sorted_token_ids, expert_ids,
num_tokens_post_padded, not apply_router_weight_on_input, 1, config,
compute_type=compute_type,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w2_bias,
)
# ...
评论区精华
Review 讨论主要聚焦在融合条件的鲁棒性上:
风险与影响
关联脉络
- PR #42855 [Bugfix] Fix DSV4 Base model swiglu limit issue in FP8 path: 同为 FP8 MoE 相关修复,虽然不直接关联,但涉及相似的 fused_moe 模块路径。
参与讨论