执行摘要
- 一句话:支持MoRI EP的FP8 blockwise量化combine
- 推荐动作:值得精读。展示了如何用枚举替换布尔标志提升可扩展性,以及如何与外部库协作安全引入新量化模式。尤其适合关注 AMD 平台性能优化的工程师。
功能与动机
Issue #24866 报告了在启用FP8 combine时GSM8K精度下降的问题,原因是缺少正确的量化校正。此PR通过集成MoRI的FP8 blockwise量化combine来解决,并配合MoRI PR #311的上游实现。
实现拆解
- 引入枚举类型:在
moriep.py 中定义 DispatchDtype(bf16/fp8/fp4)和 CombineDtype(bf16/fp8/fp8_direct_cast),替换原有的布尔标志 fp8_dispatch、fp4_dispatch,使 dtype 配置更加可扩展且类型安全。
- 修改
init_mori_op 函数:将参数从布尔改为枚举类型 dispatch_dtype 和 combine_dtype;在 combine dtype 为 fp8 时设置 combine_quant_type = "fp8_blockwise",否则保持原有逻辑。
- 环境变量支持与向后兼容:新增
SGLANG_MORI_COMBINE_DTYPE(auto/bf16/fp8/fp8_direct_cast)控制 combine dtype;统一 SGLANG_MORI_DISPATCH_DTYPE(auto/bf16/fp8/fp4)并弃用旧的 SGLANG_MORI_FP8_DISP/SGLANG_MORI_FP4_DISP;对弃用变量显示警告。
- 块大小常量与 scale_dim 计算:在文件顶部定义
FP8_BLOCK_SIZE = 128 和 MXFP4_BLOCK_SIZE = 32;在 init_mori_op 中用这些常量计算 scale_dim,取代魔法数字。
- Dockerfile 更新:将
docker/rocm.Dockerfile 中的 MORI_COMMIT 从 v1.1.1 更新为包含 FP8 blockwise combine 支持的特定 commit。
关键文件:
python/sglang/srt/layers/moe/token_dispatcher/moriep.py(模块 调度器;类别 source;类型 core-logic;符号 DispatchDtype, CombineDtype): 主要实现文件,引入 DispatchDtype、CombineDtype 枚举,修改 init_mori_op 参数和环境变量处理,增加块大小常量。
docker/rocm.Dockerfile(模块 部署脚本;类别 infra;类型 infrastructure): 更新 MoRI 版本以包含 FP8 blockwise combine kernel 支持。
关键符号:init_mori_op, CombineeDtype, DispatchDtype
关键源码片段
python/sglang/srt/layers/moe/token_dispatcher/moriep.py
主要实现文件,引入 DispatchDtype、CombineDtype 枚举,修改 init_mori_op 参数和环境变量处理,增加块大小常量。
# 块大小常量:每组共享一个 scale 的元素数
FP8_BLOCK_SIZE = 128
MXFP4_BLOCK_SIZE = 32
class DispatchDtype(Enum):
"""Dispatch 的量化类型枚举。"""
bf16 = "bfloat16"
fp8 = "float8_blockwise"
fp4 = "mxfp4_blockwise"
class CombineDtype(Enum):
"""Combine 的量化类型枚举。"""
bf16 = "bfloat16"
fp8 = "float8_blockwise"
fp8_direct_cast = "float8_direct_cast"
@lru_cache(maxsize=4)
def init_mori_op(
group,
router_topk,
num_experts,
num_local_experts,
hidden_size,
params_dtype,
num_max_dispatch_tokens_per_rank,
deepep_mode,
instance_id=0,
# 之前是 fp8_dispatch=False, fp4_dispatch=False
dispatch_dtype=DispatchDtype.bf16,
combine_dtype=CombineDtype.bf16,
enable_sdma=False,
):
# ... 其他代码 ...
# 根据 dispatch_dtype 计算 scale_dim
if dispatch_dtype == DispatchDtype.fp8:
scale_dim = hidden_size // FP8_BLOCK_SIZE
elif dispatch_dtype == DispatchDtype.fp4:
# FP4 kernel 需要保持原始 hidden_size,内部做量化
hidden_dim = hidden_size
scale_dim = hidden_size // MXFP4_BLOCK_SIZE
data_type = torch.float4_e2m1fn_x2
scale_type_size = torch.float8_e8m0fnu.itemsize
# ...
# 处理 combine_quant_type
combine_quant_type = "none"
if combine_dtype == CombineDtype.fp8:
combine_quant_type = "fp8_blockwise"
elif combine_dtype == CombineDtype.fp8_direct_cast:
combine_quant_type = "fp8_direct_cast"
# ...
评论区精华
HaiShaw 在代码审查中要求为使用的块大小添加注释。billishyahao 回应已添加注释,并解释块大小由 MoRI 内部处理,对 SGLang 端不可见。该讨论已解决。
- 块大小注释要求 (style): billishyahao 添加了注释并解释块大小由 MoRI 内部处理,对 SGLang 不可见。
风险与影响
- 风险:
- 回归风险:枚举替换可能导致旧的布尔参数配置失效,但提供了向后兼容的 env var 并保留弃用警告,风险可控。
- 性能影响:根据 PR body 表格,
fp8_blockwise combine 的吞吐量较 bf16 略低(如 fp4+fp8_blockwise: 784 tps vs fp4+bf16: 848 tps),但精度提高约2%。用户需权衡速度和精度。
- 外部依赖风险:依赖 MoRI 特定 commit,若上游更新可能需同步,但 CI 构建会验证。
- 测试覆盖风险:缺少单元测试,仅依赖手动基准测试(16 组合的 GSM8K 精度),回归检测能力较弱。
- 影响:
- 用户影响:AMD GPU 用户可通过环境变量选择 combine dtype,在精度敏感场景获得高达 94.5% 的 GSM8K 准确率(对比之前 ~91%)。对现有配置无破坏性变更。
- 系统影响:改动集中在
moriep.py(+77/-36),Dockerfile 一行变更;未涉及核心推理路径或跨模块接口。
- 团队维护成本:增加了需要跟踪 MoRI 上游的依赖,但枚举化降低了后续添加新 dtype 的复杂性。
- 风险标记:核心路径变更, 缺少测试覆盖, 外部依赖变更, 性能权衡
关联脉络
参与讨论