Prhub

#22300 [NVIDIA] Fix FP8 gemm performance with fp16 models (MInimax-M2.5)

原始 PR 作者 trevor-m 合并时间 2026-06-07 10:45 文件变更 3 提交数 4 评论 5 代码增减 +42 / -6

执行摘要

修复 FP8 GEMM 在 fp16 模型上的性能与精度问题

当使用 fp16 激活模型(如 Minimax-M2.5)在 Blackwell 上运行时,存在性能差和精度为零的问题。PR body 指出:即使 ENABLE_JIT_DEEPGEMM 为 true,由于兼容性检查在运行时失败(输出 dtype 为 float16 而非 bfloat16),权重 scale 被转换为 UE8M0 但实际回退到 triton,而 triton 需要 float32 scale 导致额外转换开销。另如果禁用 DeepGEMM,flashinfer trtllm 后端错误地将 fp32 权重当 UE8M0 解释,导致精度为 0。

建议精读。该 PR 展示了如何通过前移运行时兼容性检查来避免代价高昂的运行时回退,设计模式清晰。值得关注的是 should_deepgemm_weight_requant_ue8m0 函数的设计——将兼容性逻辑集中化、参数化,便于后续扩展其他 GEMM 后端。另外 PR body 中提供的性能对比和精度测试数据非常详尽,可作为后续类似问题定位的参考。

讨论亮点

审核评论来自 b8zhong,仅包含 /tag-and-rerun-ci 操作,未对技术方案提出异议。无其他 review 评论,整体讨论量少,方案获得批准。

实现拆解

  1. 修改 should_deepgemm_weight_requant_ue8m0model_loader/utils.py:新增 output_dtypeweight_shape 可选参数。在模型加载阶段提前检查 DeepGEMM 运行时兼容性:要求输出 dtype 为 torch.bfloat16,且 weight 的 N 维度能被 64 整除,K 维度能被 128 整除。仅在满足条件时才允许 weight scale 转换为 UE8M0。

  2. 调整 process_weights_after_loading_block_quantfp8.py:调用 should_deepgemm_weight_requant_ue8m0 时传入当前层的 orig_dtype(即输出 dtype)和 layer.weight.shape,使上述检查在 weight 处理阶段生效,避免运行时回退。

  3. 增强 flashinfer_gemm_w8a8_block_fp8_linear_with_fallbackfp8_utils.py:在 TRTLLM 后端的 fallback 条件中增加对 format_ue8m0 属性检查:当 weight scale 不是 UE8M0 格式时,直接回退到 triton 实现,防止错误使用未转换的 scale。

该 PR 无测试文件或配置变更,仅涉及源码逻辑调整。

文件 模块 状态 重要度
python/sglang/srt/model_loader/utils.py 模型加载 modified 7.93
python/sglang/srt/layers/quantization/fp8.py 量化 modified 5.7
python/sglang/srt/layers/quantization/fp8_utils.py 量化 modified 5.63

关键符号

should_deepgemm_weight_requant_ue8m0 post_load_weights

关键源码片段

python/sglang/srt/model_loader/utils.py data-contract

核心修改文件,重写了 `should_deepgemm_weight_requant_ue8m0` 函数,新增 `output_dtype` 和 `weight_shape` 参数用于提前进行 DeepGEMM 兼容性检查。同时新增 `post_load_weights` 辅助函数以支持模型权重后处理调用。

def should_deepgemm_weight_requant_ue8m0(
    weight_block_size, output_dtype=None, weight_shape=None
):
    """Should we requant fp8 weights into UE8M0 format when loading the model.    When output_dtype or weight_shape are provided, also checks that DeepGEMM
    can actually run this layer at runtime (bf16 output, N%64==0, K%128==0).
    Without these checks, scales would be converted to UE8M0 but the GEMM would
    fall back to triton which expects float32 scales, causing wrong results.
    """
    # 首先检查 DeepGEMM 是否启用且支持 UE8M0 scale
    if not (
        deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
        and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
        and weight_block_size is not None
    ):
        return False
    # 如果指定了输出 dtype,必须为 bfloat16 才能使用 DeepGEMM
    if output_dtype is not None and output_dtype != torch.bfloat16:
        return False
    # 如果指定了 weight shape,检查 N 和 K 维度是否满足 DeepGEMM 对齐要求
    if weight_shape is not None and (
        weight_shape[0] % 64 != 0 or weight_shape[1] % 128 != 0
    ):
        return False
    return True

评论区精华

无实质性讨论 other

审核者 b8zhong 仅执行了 `/tag-and-rerun-ci` 操作,未对代码逻辑提出质疑或讨论。

结论:无技术争议,PR 直接获得批准。 · 已解决

风险与影响

风险较低。主要影响 FP8 量化路径中的 DeepGEMM 兼容性分支。由于兼容性检查前移至加载阶段且基于静态 shape 和 dtype 信息,对未受影响模型的推理路径无副作用。但需注意新增的 format_ue8m0 属性检查依赖于 weight_scale 对象是否正确设置了该属性;若某个自定义量化器未设置此 flag,TRTLLM 后端可能额外回退到 triton,对性能有小幅影响。

对用户:修复 Minimax-M2.5 在 Blackwell 上的精度和性能问题,吞吐量提升 ~10%(根据 benchmark 数据)。对其他 fp16 激活模型同样受益。对系统:无破坏性接口变更,对非 Blackwell 或无 DeepGEMM 场景无影响。对团队:解决了一个棘手的回归问题,降低了后续维护成本。影响范围限定在 FP8 量化模块,程度中等。

核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论