Prhub

#21321 [Kernel] Support FlashInfer TRTLLM-Gen fused MoE for non-gated FP4 & FP8 (Nemotron)

原始 PR 作者 danielafrimi 合并时间 2026-04-29 16:28 文件变更 8 提交数 10 评论 18 代码增减 +341 / -53

执行摘要

FlashInfer TRTLLM-Gen 融合 MoE 支持非门控 FP4/FP8,加速 NemotronH-120B

NemotronH-120B 模型使用非门控 (relu2) 激活,但 FlashInfer TRTLLM-Gen MoE 内核原仅支持 gated (silu) 激活。为了在 FlashInfer TRTLLM-Gen 后端上运行 NemotronH-120B 并利用其融合 MoE 加速,需要扩展内核支持非门控激活。此外需要处理权重对齐、路由方法等适配。

值得精读。该 PR 清晰展示了如何为 FlashInfer TRTLLM-Gen MoE 后端扩展非门控激活支持,包括权重对齐策略、激活类型传递和自动 backend 选择。设计中的分支权衡和测试取舍也值得关注。建议重点关注 _align_fp8_moe_weights 函数的对齐逻辑和 activation_type 参数传递链。

讨论亮点
  • 门控判断分支设计:Fridge003 担心在 align_fp8_moe_weights_for_flashinfer_trtllm 中增加 if is_gated 分支会破坏现有逻辑,建议创建独立函数。作者未直接回复,最终仍采用分支方式,但通过统一对齐函数 _align_fp8_moe_weights 降低了复杂度。
  • 函数重复:Fridge003 指出 round_up_to_multipleutils.py 中已存在局部定义,不应重复。作者已修复,改为统一引用。
  • 测试覆盖:Fridge003 认为新增的单元测试文件(test_trtllm_moe_non_gated.py)和部分 e2e 测试非必需,建议删除。最终这些测试被移除,依赖已有 4-gpu 模型测试覆盖该特性。

实现拆解

  1. 添加基础工具函数:在 flashinfer_trtllm.py 中新增 round_up_to_multiple_is_gated_align_fp8_moe_weights_align_mxfp8_moe_weights_align_fp4_moe_weights 等辅助函数,用于权重对齐填充和门控检测。

  2. 修改权重准备函数:对 align_fp8_moe_weights_for_flashinfer_trtllm 增加门控判断,为非门控模型执行不同对齐(padding 到 128 边界),并跳过 gated 重排。类似调整 align_fp4_moe_weights_for_flashinfer_trtllmalign_mxfp8_moe_weights_for_flashinfer_trtllm

  3. 添加激活类型传递:在 flashinfer_trtllm_moe.py 的所有包装函数(trtllm_fp8_block_scale_moe_wrapper 等)中添加 activation_type 参数,并将其转换为 FlashInfer 的 ActivationType 枚举传递给底层内核。

  4. 集成到量化方法:在 modelopt_quant.pyfp8.pycompressed_tensorsnemotron_h.py 中,将激活断言从仅 "silu" 放宽为 {"silu", "relu2"},并通过 get_activation_type 将配置转换为枚举,填充到 quant_info 中。同时将 routing_method_type 改为从 layer 属性获取(默认为 DeepSeekV3 用于 NemotronH)。

  5. 自动 backend 选择:在 server_args.py 中,当 moe_runner_backend"auto" 且架构支持 SM100 且无其他 MoE 通信后端时,自动选择 flashinfer_trtllm 而非默认的 flashinfer_cutlass

  6. FP4 权重准备适配:在 quantization/utils.pyprepare_static_weights_for_trtllm_fp4_moe 中增加 is_gated 参数,控制行数计算和 permute 索引传递,使非门控 FP4 权重能被正确转换。

测试方面:review 中要求删除单元测试,最终合并未包含新测试文件,依赖现有 4-gpu 模型测试保护。

文件 模块 状态 重要度
python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py MoE 执行器 modified 8.84
python/sglang/srt/layers/quantization/modelopt_quant.py 量化方法 modified 6.74
python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py 内核封装 modified 6.32
python/sglang/srt/layers/quantization/utils.py 量化工具 modified 6.04
python/sglang/srt/server_args.py 服务配置 modified 5.91
python/sglang/srt/layers/quantization/fp8.py FP8 量化 modified 5.4
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py 压缩张量 modified 5.35
python/sglang/srt/models/nemotron_h.py 模型定义 modified 5.27

关键符号

round_up_to_multiple _is_gated _align_fp8_moe_weights _align_mxfp8_moe_weights _align_fp4_moe_weights get_activation_type align_fp8_moe_weights_for_flashinfer_trtllm align_fp4_moe_weights_for_flashinfer_trtllm prepare_static_weights_for_trtllm_fp4_moe

关键源码片段

python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py core-logic

核心变更文件,实现非门控支持的核心逻辑:权重对齐、门控检测、激活类型映射等。所有函数均在此文件内新增或修改。

def _is_gated(layer: Module) -> bool:
    """判断 MoE 层是否使用门控激活(默认 True,保持向后兼容)。"""
    is_gated = (
        getattr(layer, "moe_runner_config", None) and layer.moe_runner_config.is_gated
    )
    return True if is_gated is None else is_gated # 若未设置,视为 gated
​
​
def round_up_to_multiple(x: int, m: int) -> int:
    """将 x 向上取整到 m 的倍数(用于对齐)。"""
    return (x + m - 1) // m * m
​
​
def _align_fp8_moe_weights(
    w13: torch.Tensor,
    w2: torch.Tensor,
    is_gated: bool,
    min_alignment: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """对齐 intermediate 维度以满足 FlashInfer TRTLLM FP8 内核要求。    非门控模型需要 128 对齐,门控模型需要 16 对齐。
    返回 (w13, w2, padded_intermediate)。
    """
    num_experts, hidden_size, intermediate = w2.shape
    padded_intermediate = round_up_to_multiple(intermediate, min_alignment)
    if padded_intermediate == intermediate:
        return w13, w2, intermediate
​
    logger.info(
        "FP8 MoE: padding intermediate size from %d to %d (alignment=%d)",
        intermediate, padded_intermediate, min_alignment,
    )
​
    # 非门控时只有 up 投影,没有 gate;门控时同时有 gate 和 up
    up_mult = 2 if is_gated else 1
    padded_gate_up = up_mult * padded_intermediate
​
    padded_w13 = w13.new_zeros((num_experts, padded_gate_up, w13.shape[2]))
    padded_w13[:, : w13.shape[1], :] = w13
​
    padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
    padded_w2[:, :, :intermediate] = w2
​
    return padded_w13, padded_w2, padded_intermediate
​
​
def align_fp8_moe_weights_for_flashinfer_trtllm(layer: Module, swap_w13_halves: bool = False) -> None:
    """准备 FP8 MoE 权重/缩放因子用于 FlashInfer TRTLLM 内核。"""
    from flashinfer import shuffle_matrix_a
​
    is_gated = _is_gated(layer)
    w13_weight = cast(torch.Tensor, layer.w13_weight)
    w2_weight = cast(torch.Tensor, layer.w2_weight)
    num_experts, gate_up_dim, hidden = w13_weight.shape
​
    # 仅门控模型且 swap 标志启用时交换 W13 半部分 [Up, Gate] -> [Gate, Up]
    if swap_w13_halves and is_gated:
        inter = gate_up_dim // 2
        w13_weight = (
            w13_weight.reshape(num_experts, 2, inter, hidden)
            .flip(dims=[1])
            .reshape(num_experts, gate_up_dim, hidden)
        )
​
    # 非门控需要 128 对齐,门控需要 16 对齐
    min_alignment = 16 if is_gated else 128
    w13_weight, w2_weight, _ = _align_fp8_moe_weights(
        w13_weight, w2_weight, is_gated, min_alignment
    )
    num_experts, gate_up_dim, hidden = w13_weight.shape
​
    # 仅门控需要调用 reorder_rows_for_gated_act_gemm
    if is_gated:
        from flashinfer import reorder_rows_for_gated_act_gemm
        w13_interleaved_list = [
            reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts)
        ]
        w13_processed = torch.stack(w13_interleaved_list).reshape(
            num_experts, gate_up_dim, hidden
        )
    else:
        w13_processed = w13_weight
​
    # 后续 shuffle 和缩放因子计算与 gated 模型一致 ...

评论区精华

非门控对齐使用 if-else 分支设计风险 设计

Fridge003 在 line 106 评论:'Can we create new functions for aligning weights on non-gated models? Adding if-else like this can easily break current logic'

结论:最终保留 if-else 分支,但通过统一的 `_align_fp8_moe_weights` 函数封装逻辑,降低风险。作者未明确回复,但代码中保持分支方式。 · 已解决

round_up_to_multiple 函数重复 style

Fridge003 在 utils.py line 581 指出重复定义,应使用已有 `round_up_to_multiple`。作者回复 'fixed'。

结论:修复:删除 utils.py 中的局部 lambda,改为统一使用 `round_up_to_multiple`。 · 已解决

测试覆盖非门控功能 测试

Fridge003 在 test/registered/kernels/test_trtllm_moe_non_gated.py 和 test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py 评论说不需要这些测试,已有 4-gpu 模型测试可保护。

结论:删除这两处测试文件,依赖现有 e2e 测试覆盖非门控 MoE 功能。 · 已解决

风险与影响

  • 回归风险:对 align_fp8_moe_weights_for_flashinfer_trtllm 等核心函数添加门控分支可能影响现有 gated 模型(silu)的行为,若分支判断失误或对齐参数错误可能导致数值偏差或崩溃。
  • 测试不足:最终合并无专用单元测试覆盖非门控路径,仅依赖 e2e 测试可能遗漏边界情况(如 TP>1 时的对齐 padding、不同 intermediate 尺寸组合)。
  • 数值兼容性:权重对齐 padding 会改变中间表示维度,可能影响缩放因子计算和最终输出数值,需确保与 TensorRT-LLM 内核的预期一致。
  • 参数一致性activation_type 参数需要在多个文件(flashinfer_trtllm_moe.pymodelopt_quant.pyfp8.pycompressed_tensors)中同步传递,任何遗漏都会导致内核调用错误。
  • 用户:启用 NemotronH-120B 模型在 FlashInfer TRTLLM-Gen 后端上的运行,获得显著性能提升(1.25-1.85x)。自动 backend 选择对 SM100 用户透明。
  • 系统:仅在 SM100(B200)架构上自动启用 flashinfer_trtllm,不影响其他 GPU 平台。引入的 alignment padding 会小幅增加显存占用(padding 后的中间尺寸)。
  • 团队:增加了 MoE 内核的维护成本,未来添加新激活类型需同步修改多处;但工具函数(如 _align_fp8_moe_weights)提高了代码复用性。
核心路径变更 缺少专项测试 数值兼容性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论