执行摘要
- 一句话:WNA16 MoE 后端选择重构至 oracle 模块,新增 FlashInfer Monolithic 支持
- 推荐动作:值得精读,特别是 oracle 模式的设计和 kernel 实例存储位置的决策。关注 review 中关于 state sharing 的修改,以及后续的兼容性修复。
功能与动机
源自 #39190,旨在将 WNA16 MoE 量化方法的后端选择逻辑集中到 oracle 模块,消除 CompressedTensorsWNA16MarlinMoEMethod 中硬编码的 kernel 选择分支,使得添加新后端(如 FlashInfer TRT-LLM)更为简洁且可维护。
实现拆解
- 新增 TrtLlmMxint4ExpertsMonolithic 类(
vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py):继承 mk.FusedMoEExpertsMonolithic,实现 _supports_current_device、_supports_quant_scheme、_supports_parallel_config 等兼容性检查方法,封装 FlashInfer 的 flashinfer_trtllm_mxint4_moe 调用。
- 扩展 oracle 模块(
vllm/model_executor/layers/fused_moe/oracle/int_wna16.py):新增 FLASHINFER_TRTLLM 后端枚举,在 select_wna16_moe_backend 中将其加入优先级列表;更新 backend_to_kernel_cls 以返回 TrtLlmMxint4ExpertsMonolithic;新增 make_wna16_moe_quant_config 和 _process_weights_flashinfer 辅助函数。
- 重构 CompressedTensorsWNA16MarlinMoEMethod(
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py):删除 GPTQMarlinState 枚举和直接 flashinfer/marlin 分支,改用 select_wna16_moe_backend 获得 experts_cls;在 process_weights_after_loading 中根据后端调用 make_wna16_moe_kernel 构造 kernel 实例(存储于 layer 上以避免共享状态问题)。
- 更新配置工厂函数(
vllm/model_executor/layers/fused_moe/config.py):为 int4_w4a16_moe_quant_config 和 int8_w8a16_moe_quant_config 增加 a1_gscale/a2_gscale 参数,移除未使用的 awq_marlin_moe_quant_config。
- 同步调用方:在
awq_marlin.py 和 auto_gptq.py 中将 select_wna16_moe_backend 调用更新为新签名,并使用 make_wna16_moe_quant_config 替代旧配置函数;在 marlin_moe.py 中添加 kInt4Static32 和 int8_w8a16 支持标记。
- 测试配套:PR 描述说明“现有测试应保持一致”,未引入新测试。
关键文件:
vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py(模块 MoE 专家层;类别 source;类型 data-contract;符号 TrtLlmMxint4ExpertsMonolithic, init, _supports_current_device, _supports_no_act_and_mul): 新增文件,封装 FlashInfer TRT-LLM MxInt4 MoE 的 Monolithic 内核接口,是本次重构新增后端的关键实现。
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py(模块 量化方法;类别 source;类型 data-contract;符号 GPTQMarlinState, select_gemm_impl, is_monolithic): 核心重构文件,将原来硬编码的 kernel 选择改为通过 oracle 获取 experts_cls,并解耦了权重后处理逻辑。
vllm/model_executor/layers/fused_moe/oracle/int_wna16.py(模块 Oracle 调度器;类别 source;类型 data-contract;符号 make_wna16_moe_quant_config, _process_weights_flashinfer): Oracle 模块,集中管理 WNA16 MoE 后端选择逻辑和 kernel 构造。
vllm/model_executor/layers/fused_moe/config.py(模块 配置;类别 source;类型 data-contract;符号 awq_marlin_moe_quant_config): 更新量化配置工厂函数,为 int4/int8 MoE 配置添加激活全局 scale 参数,并移除遗弃的 awq_marlin_moe_quant_config。
vllm/model_executor/layers/quantization/awq_marlin.py(模块 量化方法;类别 source;类型 data-contract): 适配 oracle 接口修改,使用 make_wna16_moe_quant_config 替代已移除的 awq_marlin_moe_quant_config。
vllm/model_executor/layers/fused_moe/experts/marlin_moe.py(模块 MoE 专家层;类别 source;类型 data-contract): 添加 kInt4Static32 支持,扩展 int8_w8a16 支持标记。
关键符号:TrtLlmMxint4ExpertsMonolithic.init, TrtLlmMxint4ExpertsMonolithic._supports_quant_scheme, TrtLlmMxint4ExpertsMonolithic.apply, select_wna16_moe_backend, make_wna16_moe_kernel, make_wna16_moe_quant_config, _process_weights_flashinfer, CompressedTensorsWNA16MarlinMoEMethod.process_weights_after_loading, CompressedTensorsWNA16MarlinMoEMethod.apply
关键源码片段
vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py
新增文件,封装 FlashInfer TRT-LLM MxInt4 MoE 的 Monolithic 内核接口,是本次重构新增后端的关键实现。
# vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py
# 该类封装 FlashInfer 的 fused router + experts Monolithic 内核
class TrtLlmMxint4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(moe_config, quant_config)
# 从配置中提取常用参数
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = moe_config.intermediate_size_per_partition
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.ep_rank
self.routing_method = moe_config.routing_method
@staticmethod
def _supports_quant_scheme(weight_key: QuantKey | None, activation_key: QuantKey | None) -> bool:
# 仅支持 int4 权重 + 无激活量化,且 group size 为 32
return (weight_key, activation_key) == (kInt4Static32, None)
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
# FlashInfer MxInt4 使用 fused SwiGLU
return activation == MoEActivation.SWIGLUOAI
@property
def expects_unquantized_inputs(self) -> bool:
# 内核内部处理量化,输入应为未量化
return True
def apply(self, hidden_states, w1, w2, router_logits, activation, global_num_experts,
expert_map, a1q_scale, apply_router_weight_on_input, ...) -> torch.Tensor:
# 内部调用 flashinfer_trtllm_mxint4_moe,使用预处理的 scale
return flashinfer_trtllm_mxint4_moe(
x=hidden_states, router_logits=router_logits, ...)
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py
核心重构文件,将原来硬编码的 kernel 选择改为通过 oracle 获取 experts_cls,并解耦了权重后处理逻辑。
# vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_wna16_marlin.py
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(self, weight_quant, input_quant, moe, layer_name=None):
# ... 解析 weight_quant 参数
weight_key = QuantKey(self.quant_type, scale) # 根据 num_bits 和 group_size 构建 QuantKey
# 通过 oracle 选择后端和 expert 类
self.wna16_backend, self.experts_cls = select_wna16_moe_backend(
config=self.moe, weight_key=weight_key)
def process_weights_after_loading(self, layer):
# ... 权重处理后,构造 kernel 实例并挂载到 layer 上
layer.moe_quant_config = self.get_fused_moe_quant_config(layer)
layer.moe_kernel = make_wna16_moe_kernel(
moe_quant_config=layer.moe_quant_config,
experts_cls=self.experts_cls,
config=self.moe)
layer.moe_kernel.set_weight(
w13_weight=layer.w13_weight, w2_weight=layer.w2_weight,
w13_scale=layer.w13_scale, w2_scale=layer.w2_scale, ...)
def apply(self, layer, hidden_states, ...):
# 使用 layer 上的 kernel 执行
return layer.moe_kernel.apply(hidden_states, ...)
vllm/model_executor/layers/fused_moe/oracle/int_wna16.py
Oracle 模块,集中管理 WNA16 MoE 后端选择逻辑和 kernel 构造。
# vllm/model_executor/layers/fused_moe/oracle/int_wna16.py
class WNA16MoEBackend(Enum):
MARLIN = "MARLIN"
BATCHED_MARLIN = "BATCHED_MARLIN"
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" # 新增
XPU = "XPU"
def select_wna16_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey,
) -> tuple[WNA16MoEBackend, type[mk.FusedMoEExperts]]:
# 根据平台和配置按优先级尝试每个后端,检查兼容性
for backend in _get_priority_backends():
experts_cls = backend_to_kernel_cls(backend)
if not experts_cls[0]._supports_parallel_config(config.moe_parallel_config):
continue
if not experts_cls[0]._supports_quant_scheme(weight_key, None):
continue
# ... 后续检查
return backend, experts_cls[0]
raise ValueError("No suitable WNA16 MoE backend")
def make_wna16_moe_quant_config(
num_bits, group_size, w1_scale, w2_scale, ...
) -> FusedMoEQuantConfig:
# 根据 num_bits 和 group_size 构建对应的 FusedMoEQuantConfig
# 替代原先不同量化方法中的重复配置逻辑
评论区精华
风险与影响
关联脉络
- PR #39190 Derived from #39190 (referenced in body): 该 PR 是 #39190 的延续,作为 oracle 重构的基础。
参与讨论