执行摘要
- 一句话:支持 Nemotron 模型 NVFP4 权重通过 Marlin W4A16 在 SM80-SM90 上推理
- 推荐动作:建议精读:该 PR 展示了如何将专有量化格式(NVFP4 ModelOpt)映射到已有 Marlin 内核,包含 scale 转换、非门控 MoE 扩展、多后端路由等设计决策,对于理解 SGLang 的量化抽象层和 MoE 支持有参考价值。关注点:scale 转换的数值正确性、非门控 MoE 的激活函数处理、全局 scale 指数偏移的数学推导。
功能与动机
Serve NVFP4 modelopt checkpoints (modelopt_fp4) on Ampere/Hopper by routing them through Marlin W4A16 when native FP4 is not supported. (PR Body 原文)
实现拆解
-
NVFP4 到 Marlin 的 scale 适配层:在 marlin_utils_fp4.py 中新增 nvfp4_marlin_process_scales 将 per-group scale 从 FP16/BF16 转换为 Marlin 所需的 FP8E4M3 格式;新增 nvfp4_marlin_process_global_scale 将全局 scale 进行指数偏移以匹配内核期望;新增 apply_fp4_marlin_linear 作为主入口,通过 gptq_marlin_gemm 调用 Marlin 内核,支持 FP16/BF16 激活;通过 register_custom_op 注册 fake 实现用于 tracing。
-
Dense Linear 集成:在 modelopt_quant.py 的 ModelOptNvFp4LinearMethod 中,create_weights 保存 quant_config 和 params_dtype,process_weights_after_loading 检测当前后端是否为 Marlin,若是则调用 prepare_nvfp4_layer_for_marlin 执行权重和 scale 的 repack,并跳过原生的 Blackwell 路径;apply 方法也优先路由到 apply_fp4_marlin_linear。同时 get_min_capability 从 100 降至 80 以允许 SM80+ 运行。
-
MoE 扩展:fused_marlin_moe.py 中扩展 get_scalar_type 接受 global_scale 参数以匹配 NVFP4 的 scale 语义;新增 is_nvfp4_marlin 检测逻辑;fused_marlin_moe 函数增加 w1_global_scale、w2_global_scale、activation、is_gated 参数,并据此计算 gemm1_n 和控制激活函数的使用(支持 silu 和 relu2)。MarlinMoERunner 在 moe_runner/marlin.py 中传递全局 scale。
-
FP4 GEMM 后端选择:fp4_utils.py 的 Fp4GemmRunnerBackend 新增 MARLIN;initialize_fp4_gemm_config 在 SM80-SM90 且 CUDA 时自动选择 marlin,并优先于其他后端(flashinfer 等仅限 SM100+ 或 CPU)。
-
模型钩子:nemotron_h_hook.py 在检测到 modelopt_fp4 或 modelopt_mixed 量化且 SM80-SM90 时,自动设置 --moe-runner-backend marlin,避免用户手动指定。
-
测试与文档:新增单元测试验证 scale 变换数值正确性(test_gptq_marlin.py)、dense linear 与 dequant 参考的精度(test_gptq_marlin.py)、MoE 非门控 W4A16 和 NVFP4 的端到端精度(test_moe_wna16_marlin.py)。新增端到端模型测试 test_nvidia_nemotron_3_super_nvfp4.py。更新 docs_new/docs/references/fp4_gemm_backend.md 和 server_args 文档。
关键文件:
python/sglang/srt/layers/quantization/marlin_utils_fp4.py(模块 量化层;类别 source;类型 core-logic;符号 nvfp4_marlin_process_scales, nvfp4_marlin_process_global_scale, fake_apply_fp4_marlin_linear, apply_fp4_marlin_linear): 核心新增文件,实现 NVFP4 到 Marlin 的 scale 转换、dense linear apply 函数和权重 prep 逻辑。所有 NVFP4 专用内核适配均在此完成。
python/sglang/srt/layers/quantization/modelopt_quant.py(模块 量化配置;类别 source;类型 data-contract;符号 create_weights, process_weights_after_loading, apply): Dense layer 集成入口,加载权重时路由到 Marlin 准备,apply 时调用 Marlin GEMM。同时降低最低计算能力限制,使 NVFP4 量化的模型可在 SM80+ 上加载。
python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py(模块 MoE内核;类别 source;类型 core-logic;符号 get_scalar_type, fused_marlin_moe): MoE 核心函数扩展,新增 NVFP4 检测、全局 scale 支持、非门控激活函数。使 fused_marlin_moe 能够处理 NVFP4 MoE 层。
python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py(模块 MoE测试;类别 test;类型 test-coverage;符号 test_fused_marlin_moe_non_gated_relu2, test_fused_marlin_moe_nvfp4_non_gated_padded_intermediate_launches, test_fused_marlin_moe_nvfp4_non_gated_matches_dequant_reference): 新增 MoE 非门控 NVFP4 和 W4A16 的单元测试,覆盖 fused_marlin_moe 的 non-gated 路径和精度比对。
python/sglang/jit_kernel/tests/test_gptq_marlin.py(模块 内核测试;类别 test;类型 test-coverage;符号 test_nvfp4_marlin_support_and_scale_transforms_sm80_sm90, test_nvfp4_marlin_dense_matches_dequant_reference): 新增 NVFP4 scale 变换和 dense linear 精度测试,验证 scale 处理正确性以及 apply_fp4_marlin_linear 与 dequant 参考的一致性。
python/sglang/srt/layers/quantization/fp4_utils.py(模块 后端选择;类别 source;类型 core-logic;符号 Fp4GemmRunnerBackend, initialize_fp4_gemm_config): FP4 后端枚举增加 MARLIN,并在初始化函数中为 SM80-SM90 自动选择 Marlin 后端,影响所有 FP4 量化模型的推理路径。
python/sglang/srt/arg_groups/nemotron_h_hook.py(模块 模型钩子;类别 source;类型 dependency-wiring;符号 apply_nemotron_h_defaults): 模型钩子,针对 Nemotron NVFP4 模型自动设置 MoE runner 为 marlin,简化用户配置。
python/sglang/srt/layers/moe/moe_runner/marlin.py(模块 MoE执行器;类别 source;类型 core-logic;符号 MarlinMoERunner): MoE runner 增加对 global_scale 参数的传递,确保 NVFP4 全局 scale 到达 kernel。
python/sglang/srt/layers/quantization/marlin_utils.py(模块 量化工具;类别 source;类型 core-logic;符号 MARLIN_SUPPORTED_GROUP_SIZES, check_marlin_supported): 修改全局 MARLIN_SUPPORTED_GROUP_SIZES,加入 16 以支持 NVFP4 的 group_size=16,同时新增 check 逻辑避免影响其他量化类型。
python/sglang/jit_kernel/utils.py(模块 JIT工具;类别 source;类型 core-logic;符号 _local_jit_source_hash): 改进 JIT 缓存哈希机制,使其包含源代码文件内容,确保重新编译的正确性。减少因头文件变更导致的缓存未命中问题。
python/sglang/test/test_marlin_utils.py(模块 测试工具;类别 test;类型 test-coverage;符号 make_nvfp4_weight_and_ref, _unpack): 新增 make_nvfp4_weight_and_ref 辅助函数,为 NVFP4 单元测试提供构造 NVFP4 权重和 dequant 参考的函数。
python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h(模块 内核模板;类别 source;类型 core-logic): Marlin 内核模板微调,可能涉及 NVFP4 相关的配置,确保 JIT 编译通过。
关键符号:nvfp4_marlin_process_scales, nvfp4_marlin_process_global_scale, apply_fp4_marlin_linear, prepare_nvfp4_layer_for_marlin, prepare_moe_nvfp4_layer_for_marlin, fused_marlin_moe, get_scalar_type, check_marlin_supported
关键源码片段
python/sglang/srt/layers/quantization/marlin_utils_fp4.py
核心新增文件,实现 NVFP4 到 Marlin 的 scale 转换、dense linear apply 函数和权重 prep 逻辑。所有 NVFP4 专用内核适配均在此完成。
# python/sglang/srt/layers/quantization/marlin_utils_fp4.py
def nvfp4_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor:
# NVFP4 ModelOpt scales 应为非负值,但保留警告以便诊断异常 checkpoint
if not (marlin_scales >= 0).all():
import logging
logging.getLogger(__name__).warning_once(
"NVFP4 Marlin assumes non-negative scales, but negative scales "
"were found. Accuracy may be degraded."
)
# 转换为 FP16 后执行通道顺序重排 (0,2,1,3)
marlin_scales = marlin_scales.to(torch.half)
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
)
# 乘以 2^7 并左移 1 位,以 FP16 位模式表达 FP8E4M3 指数
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
# Marlin kernel 仅消费偶数索引的 scale
return marlin_scales[:, 1::2].contiguous()
@register_custom_op(fake_impl=fake_apply_fp4_marlin_linear)
def apply_fp4_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_global_scale: torch.Tensor,
workspace: torch.Tensor,
size_n: int,
size_k: int,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
if input.dtype not in (torch.float16, torch.bfloat16):
raise RuntimeError("NVFP4 Marlin requires FP16 or BF16 activations.")
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (size_n,)
use_atomic_add = should_use_atomic_add_reduce(
m=reshaped_x.size(0), n=size_n, k=size_k,
device=input.device, dtype=input.dtype,
)
# 调用 JIT 编译的 Marlin GEMM kernel,传入全局 scale 参数
output = gptq_marlin_gemm(
a=reshaped_x,
c=None,
b_q_weight=weight,
b_scales=weight_scale,
global_scale=weight_global_scale,
b_zeros=None,
g_idx=None,
perm=None,
workspace=workspace,
b_q_type=scalar_types.float4_e2m1f,
size_m=reshaped_x.size(0),
size_n=size_n,
size_k=size_k,
is_k_full=True,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py
MoE 核心函数扩展,新增 NVFP4 检测、全局 scale 支持、非门控激活函数。使 fused_marlin_moe 能够处理 NVFP4 MoE 层。
# python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py
def get_scalar_type(
num_bits: int,
has_zp: bool,
scales: Optional[torch.Tensor] = None,
global_scale: Optional[torch.Tensor] = None,
):
from sgl_kernel.scalar_type import scalar_types
# NVFP4 通过 global_scale 区分:若有 global_scale 则认为是 float4_e2m1f
if not has_zp and num_bits == 4 and scales is not None and (
scales.dtype == torch.float8_e8m0fnu or global_scale is not None
):
return scalar_types.float4_e2m1f
if has_zp:
assert num_bits == 4
return scalar_types.uint4
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
# 在 fused_marlin_moe 函数内部检测 NVFP4
is_nvfp4_marlin = (
num_bits == 4
and w1_zeros is None
and w2_zeros is None
and w1_global_scale is not None
and w2_global_scale is not None
)
if is_mxfp4_marlin:
assert hidden_states.dtype == torch.bfloat16, ...
elif not is_nvfp4_marlin:
assert hidden_states.dtype == w1_scale.dtype, ...
# 非门控时 gemm1_n = N 而不是 2*N
gemm1_n = 2 * N if is_gated else N
评论区精华
全局 group size 修改风险
TomerBN-Nvidia 指出在 MARLIN_SUPPORTED_GROUP_SIZES 中加入 16 会影响所有 Marlin 量化类型的 check_marlin_supported,建议添加 bypass。shaunkotek 随后在 check_marlin_supported 中增加了对 group_size 的额外验证,仅在 NVFP4 对应的 float4_e2m1f 时允许 16。
FP4 后端范围限制
TomerBN-Nvidia 提醒 initialize_fp4_gemm_config 中的 elif 条件会误匹配 SM10+/11+ 设备到 Marlin,shaunkotek 修复为 (8, 0) <= capability < (10, 0)。
MoE 非门控断言
TomerBN-Nvidia 建议将 is_nvfp4_marlin 加入 is_mxfp4_marlin 的断言条件。shaunkotek 解释 mxfp4 只支持 BF16、nvfp4 支持 FP16 和 BF16,因此保留分支判断,最终 TomerBN 认可。
测试覆盖建议
b8zhong 要求添加端到端模型测试,shaunkotek 新增了 test_nvidia_nemotron_3_super_nvfp4.py。
代码复用
b8zhong 建议在测试中使用已有的 is_sm90_supported / is_sm80_supported 替代自定义函数,shaunkotek 修改。
- 全局 group size 修改风险 (correctness): shaunkotek 随后在 check_marlin_supported 中增加了对 group_size 的额外验证,仅在 NVFP4 对应的 float4_e2m1f 时允许 16。
- FP4 后端范围限制 (correctness): shaunkotek 修复为 (8, 0) <= capability < (10, 0)。
- MoE 非门控断言 (correctness): 最终 TomerBN 认可当前设计,保留独立分支。
- 测试覆盖建议 (testing): shaunkotek 新增了 test_nvidia_nemotron_3_super_nvfp4.py。
风险与影响
关联脉络
参与讨论