执行摘要
- 一句话:MXFP8 MoE Group GEMM & Quant 迁移至 JIT,Blackwell 加速
- 推荐动作:值得精读,尤其是 Python 端 JIT 编译集成模式(cache_once + load_jit)和 MoE Group GEMM 的 CUTLASS 实现。2SM 策略在 memory-bound 场景的加速效果值得关注。作为系列 PR 的第 1 部分,建议跟踪后续集成。
功能与动机
现有的 sgl-kernel 中的 MXFP8 操作是预编译的,缺乏灵活性。迁移到 JIT 可以利用 TVM 动态编译适配当前 GPU,并允许采用更高效的 MoE Group GEMM 替代普通 Group GEMM,减少 TMA 更新并提高性能和内存受限场景的吞吐量。
实现拆解
-
在 python/sglang/jit_kernel/mxfp8.py 中封装 JIT 编译逻辑:定义 _mxfp8_cuda_flags() 和 _mxfp8_arch_env() 设置 NVCC 参数(启用 CUDA 12.8 sm_100a)。通过 @cache_once 的 _jit_es_sm100_mxfp8_blockscaled_group_quant 和 _jit_es_sm100_mxfp8_blockscaled_moe_group_gemm 使用 load_jit 加载 .cuh 源文件并编译为 TVM Module。对外提供两个高层函数 es_sm100_mxfp8_blockscaled_grouped_quant 和 es_sm100_mxfp8_blockscaled_moe_grouped_gemm,自动调用编译好的内核。
-
在 python/sglang/jit_kernel/csrc/moe/expert_specialization/ 下实现 CUDA 内核:
es_sm100_mxfp8_blockscaled_group_quant.cuh:按 128 元素块计算 max 并导出 fp8 和 scale factor。
es_sm100_mxfp8_blockscaled_moe_group_gemm.cuh:使用 CUTLASS GemmUniversalAdapter,以 MoEProblemShape 驱动 grouped GEMM,先通过预计算核生成指针数组,再运行主 GEMM。
es_sm100_mxfp8_blockscaled_moe_group_gemm_traits.cuh:定义 2SM 和 1SM 等 MMA 配置。
es_sm100_mxfp8_blockscaled_moe_group_gemm_functor.cuh:偏移计算 functor。
-
在 python/sglang/jit_kernel/tests/test_mxfp8_moe.py 添加单元测试:参数化 num_experts 和输出类型,通过 calc_diff 与 fp32 参考对比,精度要求 diff < 0.001。使用 register_cuda_ci 注册到 CI 套件。
-
在 python/sglang/jit_kernel/benchmark/bench_mxfp8_moe.py 添加基准测试,与 sgl-kernel 版本对比延迟。
-
通过 CI 注册确保 GPU 资源分配;lint 和格式调整。
关键文件:
python/sglang/jit_kernel/mxfp8.py(模块 JIT 封装;类别 source;类型 dependency-wiring;符号 es_sm100_mxfp8_blockscaled_grouped_quant, es_sm100_mxfp8_blockscaled_moe_grouped_gemm, _jit_es_sm100_mxfp8_blockscaled_group_quant, _jit_es_sm100_mxfp8_blockscaled_moe_group_gemm): JIT 封装入口,提供量化和 GEMM 的高层接口,串联所有 CUDA 内核。
python/sglang/jit_kernel/tests/test_mxfp8_moe.py(模块 单元测试;类别 test;类型 test-coverage;符号 test_es_sm100_mxfp8_blockscaled_grouped_mm, align, calc_diff, is_sm100_supported): 单元测试验证精度,覆盖多 expert 和输出类型。
python/sglang/jit_kernel/benchmark/bench_mxfp8_moe.py(模块 基准测试;类别 source;类型 dependency-wiring;符号 is_sm100_supported, _probe_sgl_kernel_group_mm, _prepare_case, benchmark): 基准测试比较 JIT 和 sgl-kernel 延迟。
python/sglang/jit_kernel/csrc/moe/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh(模块 CUDA 内核;类别 other;类型 dependency-wiring): 量化的 CUDA 内核实现,使用 CUTLASS 块缩放方法。
python/sglang/jit_kernel/csrc/moe/expert_specialization/es_sm100_mxfp8_blockscaled_moe_group_gemm.cuh(模块 CUDA 内核;类别 other;类型 dependency-wiring): MoE 组 GEMM 内核,包含预计算核和 grouped gemm 启动。
关键符号:es_sm100_mxfp8_blockscaled_grouped_quant, es_sm100_mxfp8_blockscaled_moe_grouped_gemm, _jit_es_sm100_mxfp8_blockscaled_group_quant, _jit_es_sm100_mxfp8_blockscaled_moe_group_gemm, _prepare_case, benchmark, test_es_sm100_mxfp8_blockscaled_grouped_mm
关键源码片段
python/sglang/jit_kernel/mxfp8.py
JIT 封装入口,提供量化和 GEMM 的高层接口,串联所有 CUDA 内核。
# JIT 编译函数(每个 dtype 只编译一次)
@cache_once
def _jit_es_sm100_mxfp8_blockscaled_group_quant(dtype: torch.dtype) -> Module:
args = make_cpp_args(dtype)
with _mxfp8_arch_env():
return load_jit(
'es_sm100_mxfp8_blockscaled_group_quant',
*args,
cuda_files=['moe/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh'],
cuda_wrappers=[('es_sm100_mxfp8_blockscaled_group_quant',
f'EsSm100MXFP8BlockscaledGroupQuant<{args}>::run')],
extra_dependencies=['cutlass'],
extra_cuda_cflags=_mxfp8_cuda_flags(),
)
@cache_once
def _jit_es_sm100_mxfp8_blockscaled_moe_group_gemm(dtype: torch.dtype) -> Module:
args = make_cpp_args(dtype)
with _mxfp8_arch_env():
return load_jit(
'es_sm100_mxfp8_blockscaled_moe_group_gemm',
*args,
cuda_files=['moe/expert_specialization/es_sm100_mxfp8_blockscaled_moe_group_gemm.cuh'],
cuda_wrappers=[('es_sm100_mxfp8_blockscaled_moe_group_gemm',
f'EsSm100MXFP8BlockscaledMoeGroupGemm<{args}>::run')],
extra_dependencies=['cutlass'],
extra_cuda_cflags=_mxfp8_cuda_flags(),
)
def es_sm100_mxfp8_blockscaled_grouped_quant(
input: torch.Tensor,
tokens_per_expert: torch.Tensor,
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
quant_output: torch.Tensor,
scale_factor: torch.Tensor,
) -> None:
# 对拼接的 MoE 输入执行块缩放 MXFP8 量化
module = _jit_es_sm100_mxfp8_blockscaled_group_quant(input.dtype)
module.es_sm100_mxfp8_blockscaled_group_quant(
input, tokens_per_expert, expert_offsets,
blockscale_offsets, quant_output, scale_factor,
)
def es_sm100_mxfp8_blockscaled_moe_grouped_gemm(
a: torch.Tensor,
b: torch.Tensor,
sfa: torch.Tensor,
sfb: torch.Tensor,
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
tokens_per_expert: torch.Tensor,
workspace: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
# MoE 组 GEMM:针对每个专家独立执行矩阵乘法,使用 MXFP8 输入和输出
num_experts, m, tokens = a.shape[0], a.shape[1], b.shape[0]
d = torch.empty((tokens, m), device=a.device, dtype=dtype)
# 设备端指针数组(每次调用分配,review 指出可优化)
d_ptrs = torch.empty((num_experts,), device=a.device, dtype=torch.int64)
b_ptrs = torch.empty((num_experts,), device=a.device, dtype=torch.int64)
sfb_ptrs = torch.empty((num_experts,), device=a.device, dtype=torch.int64)
module = _jit_es_sm100_mxfp8_blockscaled_moe_group_gemm(dtype)
module.es_sm100_mxfp8_blockscaled_moe_group_gemm(
a, b, sfa, sfb, expert_offsets, blockscale_offsets,
tokens_per_expert, b_ptrs, sfb_ptrs, d, d_ptrs, workspace,
)
return d
评论区精华
风险与影响
- 风险:
- 硬件兼容性:仅支持 SM100+(Blackwell)且需要 CUDA 12.8+,旧硬件将抛出异常。
- 性能风险:JIT 编译首次调用有额外开销,但通过 @cache_once 只发生一次。
- 正确性:测试覆盖 8/16/32/64 expert 和随机 m(1-512),diff < 0.001,但未覆盖所有边界(如 m > 512)。
- 维护复杂性:CUDA 内核依赖 CUTLASS 模板,代码易于出错且调试困难。
- 热路径指针分配优化未解决:每次 GEMM 调用分配设备指针数组,可能影响频繁调用场景。
- 影响:
- 用户:Blackwell GPU 且使用 MXFP8 量化的用户可获得 5-15% 延迟降低;其他硬件用户不受影响。
- 系统:扩展 jit_kernel 模块,新增约 1400 行代码,需维护 CUTLASS 依赖。
- 团队:新增 JIT 内核开发和维护成本,但提供更灵活的优化路径;后续 PR 将集成此内核到 MoE 推理流程。
- 风险标记:Blackwell 独占, 热路径分配问题
关联脉络
参与讨论