Prhub

#24879 [AMD] support fp8 blockwise quantization combine for mori ep

原始 PR 作者 billishyahao 合并时间 2026-05-13 14:24 文件变更 2 提交数 8 评论 3 代码增减 +78 / -37

执行摘要

支持 MoRI EP 的 FP8 blockwise 量化 combine

Issue #24866 报告了在启用FP8 combine时GSM8K精度下降的问题,原因是缺少正确的量化校正。此PR通过集成MoRI的FP8 blockwise量化combine来解决,并配合MoRI PR #311的上游实现。

值得精读。展示了如何用枚举替换布尔标志提升可扩展性,以及如何与外部库协作安全引入新量化模式。尤其适合关注 AMD 平台性能优化的工程师。

讨论亮点

HaiShaw 在代码审查中要求为使用的块大小添加注释。billishyahao 回应已添加注释,并解释块大小由 MoRI 内部处理,对 SGLang 端不可见。该讨论已解决。

实现拆解

  1. 引入枚举类型:在 moriep.py 中定义 DispatchDtype(bf16/fp8/fp4)和 CombineDtype(bf16/fp8/fp8_direct_cast),替换原有的布尔标志 fp8_dispatchfp4_dispatch,使 dtype 配置更加可扩展且类型安全。
  2. 修改 init_mori_op 函数:将参数从布尔改为枚举类型 dispatch_dtypecombine_dtype;在 combine dtype 为 fp8 时设置 combine_quant_type = "fp8_blockwise",否则保持原有逻辑。
  3. 环境变量支持与向后兼容:新增 SGLANG_MORI_COMBINE_DTYPE(auto/bf16/fp8/fp8_direct_cast)控制 combine dtype;统一 SGLANG_MORI_DISPATCH_DTYPE(auto/bf16/fp8/fp4)并弃用旧的 SGLANG_MORI_FP8_DISP/SGLANG_MORI_FP4_DISP;对弃用变量显示警告。
  4. 块大小常量与 scale_dim 计算:在文件顶部定义 FP8_BLOCK_SIZE = 128MXFP4_BLOCK_SIZE = 32;在 init_mori_op 中用这些常量计算 scale_dim,取代魔法数字。
  5. Dockerfile 更新:将 docker/rocm.Dockerfile 中的 MORI_COMMITv1.1.1 更新为包含 FP8 blockwise combine 支持的特定 commit。
文件 模块 状态 重要度
python/sglang/srt/layers/moe/token_dispatcher/moriep.py 调度器 modified 7.97
docker/rocm.Dockerfile 部署脚本 modified 2.64

关键符号

init_mori_op CombineeDtype DispatchDtype

关键源码片段

python/sglang/srt/layers/moe/token_dispatcher/moriep.py core-logic

主要实现文件,引入 DispatchDtype、CombineDtype 枚举,修改 init_mori_op 参数和环境变量处理,增加块大小常量。

# 块大小常量:每组共享一个 scale 的元素数
FP8_BLOCK_SIZE = 128
MXFP4_BLOCK_SIZE = 32
​
​
class DispatchDtype(Enum):
    """Dispatch 的量化类型枚举。"""
    bf16 = "bfloat16"
    fp8 = "float8_blockwise"
    fp4 = "mxfp4_blockwise"
​
​
class CombineDtype(Enum):
    """Combine 的量化类型枚举。"""
    bf16 = "bfloat16"
    fp8 = "float8_blockwise"
    fp8_direct_cast = "float8_direct_cast"
​
​
@lru_cache(maxsize=4)
def init_mori_op(
    group,
    router_topk,
    num_experts,
    num_local_experts,
    hidden_size,
    params_dtype,
    num_max_dispatch_tokens_per_rank,
    deepep_mode,
    instance_id=0,
    # 之前是 fp8_dispatch=False, fp4_dispatch=False
    dispatch_dtype=DispatchDtype.bf16,
    combine_dtype=CombineDtype.bf16,
    enable_sdma=False,
):
    # ... 其他代码 ...
    # 根据 dispatch_dtype 计算 scale_dim
    if dispatch_dtype == DispatchDtype.fp8:
        scale_dim = hidden_size // FP8_BLOCK_SIZE
    elif dispatch_dtype == DispatchDtype.fp4:
        # FP4 kernel 需要保持原始 hidden_size,内部做量化
        hidden_dim = hidden_size
        scale_dim = hidden_size // MXFP4_BLOCK_SIZE
        data_type = torch.float4_e2m1fn_x2
        scale_type_size = torch.float8_e8m0fnu.itemsize
    # ...
    # 处理 combine_quant_type
    combine_quant_type = "none"
    if combine_dtype == CombineDtype.fp8:
        combine_quant_type = "fp8_blockwise"
    elif combine_dtype == CombineDtype.fp8_direct_cast:
        combine_quant_type = "fp8_direct_cast"
    # ...

评论区精华

块大小注释要求 style

HaiShaw 在审查中要求为使用的块大小添加注释。

结论:billishyahao 添加了注释并解释块大小由 MoRI 内部处理,对 SGLang 不可见。 · 已解决

风险与影响

  • 回归风险:枚举替换可能导致旧的布尔参数配置失效,但提供了向后兼容的 env var 并保留弃用警告,风险可控。
  • 性能影响:根据 PR body 表格,fp8_blockwise combine 的吞吐量较 bf16 略低(如 fp4+fp8_blockwise: 784 tps vs fp4+bf16: 848 tps),但精度提高约2%。用户需权衡速度和精度。
  • 外部依赖风险:依赖 MoRI 特定 commit,若上游更新可能需同步,但 CI 构建会验证。
  • 测试覆盖风险:缺少单元测试,仅依赖手动基准测试(16 组合的 GSM8K 精度),回归检测能力较弱。
  • 用户影响:AMD GPU 用户可通过环境变量选择 combine dtype,在精度敏感场景获得高达 94.5% 的 GSM8K 准确率(对比之前 ~91%)。对现有配置无破坏性变更。
  • 系统影响:改动集中在 moriep.py(+77/-36),Dockerfile 一行变更;未涉及核心推理路径或跨模块接口。
  • 团队维护成本:增加了需要跟踪 MoRI 上游的依赖,但枚举化降低了后续添加新 dtype 的复杂性。
核心路径变更 缺少测试覆盖 外部依赖变更 性能权衡

关联 Issue

#24866 [Bug] SGLang Integration with MoRI-EP IntraNode FP8 Combine Accuracy Failing due to Missing Quant Correction

完整报告

参与讨论