Prhub

#25655 Feat/add w4a16 moe support to nemotron

原始 PR 作者 shaunkotek 合并时间 2026-06-03 13:42 文件变更 19 提交数 34 评论 22 代码增减 +999 / -61

执行摘要

支持 Nemotron 模型 NVFP4 权重通过 Marlin W4A16 在 SM80-SM90 上推理

Serve NVFP4 modelopt checkpoints (modelopt_fp4) on Ampere/Hopper by routing them through Marlin W4A16 when native FP4 is not supported. (PR Body 原文)

建议精读:该 PR 展示了如何将专有量化格式(NVFP4 ModelOpt)映射到已有 Marlin 内核,包含 scale 转换、非门控 MoE 扩展、多后端路由等设计决策,对于理解 SGLang 的量化抽象层和 MoE 支持有参考价值。关注点:scale 转换的数值正确性、非门控 MoE 的激活函数处理、全局 scale 指数偏移的数学推导。

讨论亮点

全局 group size 修改风险

TomerBN-Nvidia 指出在 MARLIN_SUPPORTED_GROUP_SIZES 中加入 16 会影响所有 Marlin 量化类型的 check_marlin_supported,建议添加 bypass。shaunkotek 随后在 check_marlin_supported 中增加了对 group_size 的额外验证,仅在 NVFP4 对应的 float4_e2m1f 时允许 16。

FP4 后端范围限制

TomerBN-Nvidia 提醒 initialize_fp4_gemm_config 中的 elif 条件会误匹配 SM10+/11+ 设备到 Marlin,shaunkotek 修复为 (8, 0) <= capability < (10, 0)

MoE 非门控断言

TomerBN-Nvidia 建议将 is_nvfp4_marlin 加入 is_mxfp4_marlin 的断言条件。shaunkotek 解释 mxfp4 只支持 BF16、nvfp4 支持 FP16 和 BF16,因此保留分支判断,最终 TomerBN 认可。

测试覆盖建议

b8zhong 要求添加端到端模型测试,shaunkotek 新增了 test_nvidia_nemotron_3_super_nvfp4.py

代码复用

b8zhong 建议在测试中使用已有的 is_sm90_supported / is_sm80_supported 替代自定义函数,shaunkotek 修改。

实现拆解

  1. NVFP4 到 Marlin 的 scale 适配层:在 marlin_utils_fp4.py 中新增 nvfp4_marlin_process_scales 将 per-group scale 从 FP16/BF16 转换为 Marlin 所需的 FP8E4M3 格式;新增 nvfp4_marlin_process_global_scale 将全局 scale 进行指数偏移以匹配内核期望;新增 apply_fp4_marlin_linear 作为主入口,通过 gptq_marlin_gemm 调用 Marlin 内核,支持 FP16/BF16 激活;通过 register_custom_op 注册 fake 实现用于 tracing。

  2. Dense Linear 集成:在 modelopt_quant.pyModelOptNvFp4LinearMethod 中,create_weights 保存 quant_configparams_dtypeprocess_weights_after_loading 检测当前后端是否为 Marlin,若是则调用 prepare_nvfp4_layer_for_marlin 执行权重和 scale 的 repack,并跳过原生的 Blackwell 路径;apply 方法也优先路由到 apply_fp4_marlin_linear。同时 get_min_capability 从 100 降至 80 以允许 SM80+ 运行。

  3. MoE 扩展fused_marlin_moe.py 中扩展 get_scalar_type 接受 global_scale 参数以匹配 NVFP4 的 scale 语义;新增 is_nvfp4_marlin 检测逻辑;fused_marlin_moe 函数增加 w1_global_scalew2_global_scaleactivationis_gated 参数,并据此计算 gemm1_n 和控制激活函数的使用(支持 silu 和 relu2)。MarlinMoERunnermoe_runner/marlin.py 中传递全局 scale。

  4. FP4 GEMM 后端选择fp4_utils.pyFp4GemmRunnerBackend 新增 MARLINinitialize_fp4_gemm_config 在 SM80-SM90 且 CUDA 时自动选择 marlin,并优先于其他后端(flashinfer 等仅限 SM100+ 或 CPU)。

  5. 模型钩子nemotron_h_hook.py 在检测到 modelopt_fp4modelopt_mixed 量化且 SM80-SM90 时,自动设置 --moe-runner-backend marlin,避免用户手动指定。

  6. 测试与文档:新增单元测试验证 scale 变换数值正确性(test_gptq_marlin.py)、dense linear 与 dequant 参考的精度(test_gptq_marlin.py)、MoE 非门控 W4A16 和 NVFP4 的端到端精度(test_moe_wna16_marlin.py)。新增端到端模型测试 test_nvidia_nemotron_3_super_nvfp4.py。更新 docs_new/docs/references/fp4_gemm_backend.mdserver_args 文档。

文件 模块 状态 重要度
python/sglang/srt/layers/quantization/marlin_utils_fp4.py 量化层 modified 8.84
python/sglang/srt/layers/quantization/modelopt_quant.py 量化配置 modified 7.65
python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py MoE 内核 modified 7.45
python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py MoE 测试 modified 7.3
python/sglang/jit_kernel/tests/test_gptq_marlin.py 内核测试 modified 6.73
python/sglang/srt/layers/quantization/fp4_utils.py 后端选择 modified 6.68
python/sglang/srt/arg_groups/nemotron_h_hook.py 模型钩子 modified 6.21
python/sglang/srt/layers/moe/moe_runner/marlin.py MoE 执行器 modified 6.21
python/sglang/srt/layers/quantization/marlin_utils.py 量化工具 modified 6.15
python/sglang/jit_kernel/utils.py JIT 工具 modified 6.71
python/sglang/test/test_marlin_utils.py 测试工具 modified 5.79
python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h 内核模板 modified 5.47

关键符号

nvfp4_marlin_process_scales nvfp4_marlin_process_global_scale apply_fp4_marlin_linear prepare_nvfp4_layer_for_marlin prepare_moe_nvfp4_layer_for_marlin fused_marlin_moe get_scalar_type check_marlin_supported

关键源码片段

python/sglang/srt/layers/quantization/marlin_utils_fp4.py core-logic

核心新增文件,实现 NVFP4 到 Marlin 的 scale 转换、dense linear apply 函数和权重 prep 逻辑。所有 NVFP4 专用内核适配均在此完成。

# python/sglang/srt/layers/quantization/marlin_utils_fp4.pydef nvfp4_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor:
    # NVFP4 ModelOpt scales 应为非负值,但保留警告以便诊断异常 checkpoint
    if not (marlin_scales >= 0).all():
        import logging
        logging.getLogger(__name__).warning_once(
            "NVFP4 Marlin assumes non-negative scales, but negative scales "
            "were found. Accuracy may be degraded."
        )
​
    # 转换为 FP16 后执行通道顺序重排 (0,2,1,3)
    marlin_scales = marlin_scales.to(torch.half)
    marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
        marlin_scales.size(0), -1
    )
    # 乘以 2^7 并左移 1 位,以 FP16 位模式表达 FP8E4M3 指数
    marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
    marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
    # Marlin kernel 仅消费偶数索引的 scale
    return marlin_scales[:, 1::2].contiguous()
​
​
@register_custom_op(fake_impl=fake_apply_fp4_marlin_linear)
def apply_fp4_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_global_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: torch.Tensor | None = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    if input.dtype not in (torch.float16, torch.bfloat16):
        raise RuntimeError("NVFP4 Marlin requires FP16 or BF16 activations.")
​
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)
​
    use_atomic_add = should_use_atomic_add_reduce(
        m=reshaped_x.size(0), n=size_n, k=size_k,
        device=input.device, dtype=input.dtype,
    )
​
    # 调用 JIT 编译的 Marlin GEMM kernel,传入全局 scale 参数
    output = gptq_marlin_gemm(
        a=reshaped_x,
        c=None,
        b_q_weight=weight,
        b_scales=weight_scale,
        global_scale=weight_global_scale,
        b_zeros=None,
        g_idx=None,
        perm=None,
        workspace=workspace,
        b_q_type=scalar_types.float4_e2m1f,
        size_m=reshaped_x.size(0),
        size_n=size_n,
        size_k=size_k,
        is_k_full=True,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
    )
​
    if bias is not None:
        output.add_(bias)
​
    return output.reshape(out_shape)
python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py core-logic

MoE 核心函数扩展,新增 NVFP4 检测、全局 scale 支持、非门控激活函数。使 fused_marlin_moe 能够处理 NVFP4 MoE 层。

# python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.pydef get_scalar_type(
    num_bits: int,
    has_zp: bool,
    scales: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
):
    from sgl_kernel.scalar_type import scalar_types
​
    # NVFP4 通过 global_scale 区分:若有 global_scale 则认为是 float4_e2m1f
    if not has_zp and num_bits == 4 and scales is not None and (
        scales.dtype == torch.float8_e8m0fnu or global_scale is not None
    ):
        return scalar_types.float4_e2m1f
    if has_zp:
        assert num_bits == 4
        return scalar_types.uint4
    else:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
​
​
# 在 fused_marlin_moe 函数内部检测 NVFP4
is_nvfp4_marlin = (
    num_bits == 4
    and w1_zeros is None
    and w2_zeros is None
    and w1_global_scale is not None
    and w2_global_scale is not None
)if is_mxfp4_marlin:
    assert hidden_states.dtype == torch.bfloat16, ...
elif not is_nvfp4_marlin:
    assert hidden_states.dtype == w1_scale.dtype, ...# 非门控时 gemm1_n = N 而不是 2*N
gemm1_n = 2 * N if is_gated else N

评论区精华

全局 group size 修改风险 正确性

TomerBN-Nvidia 指出在 MARLIN_SUPPORTED_GROUP_SIZES 中加入 16 会影响所有 Marlin 量化类型的 check_marlin_supported,建议添加 bypass。

结论:shaunkotek 随后在 check_marlin_supported 中增加了对 group_size 的额外验证,仅在 NVFP4 对应的 float4_e2m1f 时允许 16。 · 已解决

FP4 后端范围限制 正确性

TomerBN-Nvidia 提醒 initialize_fp4_gemm_config 中的 elif 条件会误匹配 SM10+/11+ 设备到 Marlin,建议限制在 < (10, 0)。

结论:shaunkotek 修复为 (8, 0) <= capability < (10, 0)。 · 已解决

MoE 非门控断言 正确性

TomerBN-Nvidia 建议将 is_nvfp4_marlin 加入 is_mxfp4_marlin 的断言条件。shaunkotek 解释 mxfp4 只支持 BF16、nvfp4 支持 FP16 和 BF16,因此保留分支判断。

结论:最终 TomerBN 认可当前设计,保留独立分支。 · 已解决

测试覆盖建议 测试

b8zhong 要求添加端到端模型测试到 sglang/test/manual/models/,以便追踪 Marlin 路径的准确率。

结论:shaunkotek 新增了 test_nvidia_nemotron_3_super_nvfp4.py。 · 已解决

风险与影响

  • 数值精度风险:NVFP4 到 Marlin W4A16 的 scale 转换涉及指数偏移和浮点 reinterpret,可能引入精度损失。已有 dense linear 和 MoE 单元测试,但 CI 中由于 JIT 编译时间过长被 skip,需手动运行验证。
  • 全局影响MARLIN_SUPPORTED_GROUP_SIZES 的修改经 bypass 后影响受控,但仍需警惕其他量化类型意外接受 group size 16。
  • 兼容性:NVFP4 路径仅在 SM80-SM90 启用,SM100+ 仍使用原生 FP4 后端,不会影响已有功能。
  • 性能:Marlin W4A16 相比原生 FP4 可能略慢,但这是必要的 fallback。
  • 用户:Nemotron NVFP4 模型现在可以在 A100/H100 上推理,需指定 --fp4-gemm-backend marlin(或由 nemotron hook 自动选择)。
  • 系统:新增 Marlin 后端,不影响现有 flashinfer/cutlass 后端选择逻辑。
  • 团队:需维护 NVFP4 适配层和测试,CI 中包括端到端模型测试(但被 skip,待解决编译时间问题)。
核心量化路径变更 测试在 CI 中被 skip 全局 group size 兼容性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论