执行摘要
- 一句话:重构 DeepEP 调度器,引入结构化输出 dtype 控制
- 推荐动作:该 PR 值得精读,尤其是自动 dtype 检测函数的设计和量化配置传递方式。对于调度器重构部分,可以学习如何将分散的环境变量集中为结构化的枚举和服务器参数。建议关注全局环境变量问题的后续修复。
功能与动机
之前的调度管道存在多个缺陷:硬编码 FP8 假设导致非 FP8 模型(如 BF16 预训练模型)输出错误;环境变量 SGLANG_DEEPEP_BF16_DISPATCH 工作区脆弱,未在所有代码路径中正确传播;缺少 scheme-aware 的调度,导致 ModelSlimW4A4Int4MoE 等量化方案需要手动调整。详见 PR 描述中的 "Hardcoded FP8 assumption"、"Fragile env-var workaround" 和 "No scheme‑aware dispatch"。
实现拆解
- 定义枚举与自动检测函数:在
moe/utils.py 中新增 DeepEPOutputDtype 枚举(BF16/FP8/INT8/NVFP4)和 get_deepep_output_dtype() 函数,该函数按优先级依次检查:服务器参数 -> 废弃的环境变量 -> 量化配置中的 input_global_scale(NVFP4)-> 量化配置中的 dispatcher_output_dtype -> 后端类型(flashinfer_cutedsl/cutlass 要求 BF16)-> NPU 默认 BF16 -> 其他默认 FP8。
- 新增服务器参数:在
server_args.py 中添加 --deepep-dispatcher-output-dtype,接受 auto、bf16、fp8、int8、nvfp4 选项,默认 auto。
- 重构调度器基类:在
deepep.py 的 _DeepEPDispatcherImplBase.__init__ 中调用 set_deepep_dispatcher_dtype(),该方法从 get_deepep_output_dtype 获取 dtype 并设置 self.use_fp8 和 self.use_nvfp4 标志,同时提供 _validate_and_adjust_dtype(硬件校验与回退)和 _update_int8_quant_env(NPU INT8 环境变量设置)。dispatch_a 和 _dispatch_core 中不再局部计算 use_fp8,直接使用实例变量。
- NPU 量化方法增强:在
fused_moe_method_npu.py 的 NPUW4A4Int4DynamicMoEMethod 等类中添加 apply_without_routing_weights 方法,实现无路由权重的前向计算;同时在 process_weights_after_loading 中通过 layer.dispatcher.set_quant_config 传递正确的 dispatcher_output_dtype(如 "bf16" 或 "int8"),确保调度器知道该用哪个 dtype。
- 模型清理:移除
qwen3_5_mtp.py、qwen3_next_mtp.py、deepseek_nextn.py 等文件中旧的 envs.SGLANG_DEEPEP_BF16_DISPATCH.override 上下文管理器,改为自动检测,简化代码。
- 测试与文档:更新了 4-GPU 测试和手动测试中的参数名,修正文档中的环境变量引用,并新增了 int8 dtype 的文档说明。
关键文件:
python/sglang/srt/layers/moe/utils.py(模块 MoE工具层;类别 source;类型 core-logic;符号 DeepEPOutputDtype, get_deepep_output_dtype): 核心变更文件:定义了 DeepEPOutputDtype 枚举和自动检测函数 get_deepep_output_dtype(),是整个重构的基石。
python/sglang/srt/layers/moe/token_dispatcher/deepep.py(模块 调度器;类别 source;类型 dependency-wiring;符号 set_deepep_dispatcher_dtype, _validate_and_adjust_dtype, _update_int8_quant_env): 调度器基类重构,新增 set_deepep_dispatcher_dtype 方法,在初始化时自动设置 use_fp8/use_nvfp4 标志,并调整硬件兼容性。
python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py(模块 NPU 量化;类别 source;类型 core-logic;符号 apply_without_routing_weights): NPU 量化方法增强:为 NPUW4A4Int4DynamicMoEMethod 等类添加 apply_without_routing_weights 方法,并在 process_weights_after_loading 中通过 dispatcher.set_quant_config 传递 dtype,实现量化感知。
关键符号:get_deepep_output_dtype, set_deepep_dispatcher_dtype, _validate_and_adjust_dtype, _update_int8_quant_env, apply_without_routing_weights
关键源码片段
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
调度器基类重构,新增 set_deepep_dispatcher_dtype 方法,在初始化时自动设置 use_fp8/use_nvfp4 标志,并调整硬件兼容性。
# python/sglang/srt/layers/moe/token_dispatcher/deepep.py ( 部分 )
class _DeepEPDispatcherImplBase:
# ... __init__ 初始化 self.quant_config 等
def __init__(self, **kwargs):
# ... 其他初始化
self.quant_config: Optional[dict] = None
self.set_deepep_dispatcher_dtype() # 新增:集中设置输出 dtype
def set_quant_config(self, quant_config: dict) -> None:
self.quant_config = quant_config
self.set_deepep_dispatcher_dtype() # 当量化配置变化时重新设置
def set_deepep_dispatcher_dtype(self) -> None:
# 调用自动检测函数获取输出 dtype
self.deepep_output_dtype = get_deepep_output_dtype(self)
# dtype 到标志的映射表
config_map = {
DeepEPOutputDtype.BF16: {"use_fp8": False, "use_nvfp4": False},
DeepEPOutputDtype.FP8: {"use_fp8": True, "use_nvfp4": False},
# INT8 用于 NPU A2/A3,虽然 use_fp8 为 True 但实际会走 int8 量化
DeepEPOutputDtype.INT8: {"use_fp8": True, "use_nvfp4": False},
DeepEPOutputDtype.NVFP4: {"use_fp8": False, "use_nvfp4": True},
}
# 根据硬件进行校验和调整(例如 NPU 上 FP8 回退为 INT8)
self._validate_and_adjust_dtype()
# 应用配置
config = config_map[self.deepep_output_dtype]
self.use_fp8 = config["use_fp8"]
self.use_nvfp4 = config["use_nvfp4"]
# NPU 下设置环境变量(供底层库使用)
if _is_npu:
self._update_int8_quant_env()
def _validate_and_adjust_dtype(self) -> None:
"""根据硬件校验 dtype 并在必要时回退"""
if _is_npu:
if self.deepep_output_dtype == DeepEPOutputDtype.FP8:
logger.warning_once(
"Ascend A2/A3 NPU 不支持 fp8 deepep_dispatcher_output_dtype,切换到 int8..."
)
self.deepep_output_dtype = DeepEPOutputDtype.INT8
elif self.deepep_output_dtype == DeepEPOutputDtype.NVFP4:
raise RuntimeError(
"Ascend A2/A3 NPU 不支持 nvfp4 deepep_dispatcher_output_dtype."
)
else:
if self.deepep_output_dtype == DeepEPOutputDtype.INT8:
logger.warning_once(
"GPU 不支持 int8 deepep_dispatcher_output_dtype,切换到 fp8..."
)
self.deepep_output_dtype = DeepEPOutputDtype.FP8
# NVFP4 在 GPU 上支持,无需调整
def _update_int8_quant_env(self) -> None:
"""更新 NPU int8 量化所需的环境变量"""
if self.use_fp8:
os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1"
else:
os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0"
python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py
NPU 量化方法增强:为 NPUW4A4Int4DynamicMoEMethod 等类添加 apply_without_routing_weights 方法,并在 process_weights_after_loading 中通过 dispatcher.set_quant_config 传递 dtype,实现量化感知。
# python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py ( 部分 )
class NPUW4A4Int4DynamicMoEMethod(_NPUFusedMoEMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# ... 权重处理 ( 格式转换、打包等 )
# 关键新增:通知调度器该层应使用 BF16 调度 ( 因为 DeepEP 不支持 int4 量化 )
if hasattr(layer, "dispatcher"):
layer.dispatcher.set_quant_config({"dispatcher_output_dtype": "bf16"})
def apply_without_routing_weights(
self,
layer,
hidden_states,
hidden_states_scale,
group_list_type,
group_list,
output_dtype,
):
# 对输入进行 int4 动态量化
hidden_states, hidden_states_scale = torch.ops.npu.npu_dynamic_quant(
hidden_states, dst_type=torch.quint4x2
)
# gmm1: up_gate_proj
hidden_states = torch.ops.npu.npu_grouped_matmul(
x=[hidden_states],
weight=[layer.w13_weight],
scale=[layer.w13_weight_scale],
per_token_scale=[hidden_states_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch.ops.npu.npu_swiglu(hidden_states)
hidden_states, pertoken_scale = torch.ops.npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch.ops.npu.npu_grouped_matmul(
x=[hidden_states],
weight=[layer.w2_weight],
scale=[layer.w2_weight_scale.to(output_dtype)],
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
评论区精华
风险与影响
关联脉络
- PR #22918 [FlashInfer v0.6.11] [RL] Support FlashInfer per-token NVFP4 MoE: 该 PR 引入了 NVFP4 量化支持,与本 PR 的 NVFP4 调度输出 dtype 直接相关,共同为量化 MoE 模型提供完整支持。
- PR #25396 fix: fix deepseek v4 CP error: 同为 deepseek 模型修复,涉及 MoE 层的正确性,与本 PR 的调度器重构共同提升了 deepseek 系列模型的稳定性和量化兼容性。
参与讨论