Prhub

#26473 [MoE] Support BF16 standard A2A with DeepGEMM runner

原始 PR 作者 popsiclexu 合并时间 2026-06-02 11:40 文件变更 5 提交数 4 评论 5 代码增减 +45 / -26

执行摘要

修复 DeepGEMM runner 中 BF16 A2A 和专家 0 遗漏

The main motivation of this PR is to fix the DeepGEMM MoE runner path when --moe-runner-backend deep_gemm is used with the standard MoE A2A backend. The primary issue is that post_reorder_triton_kernel did not handle expert 0 when combining routed outputs. As a result, routed outputs from expert 0 were skipped in the DeepGEMM runner combine step, which affected model accuracy. This PR also fixes BF16 model support for the same backend combination. With BF16 experts and --moe-runner-backend deep_gemm, the standard A2A to DeepGEMM preprocess path could produce FP8 activations while the runner selected the BF16 masked DeepGEMM path, failing during CUDA graph capture.

建议精读。该 PR 解决了实际运行中的关键问题,并展示了在 Triton kernel 中如何安全地提升数值精度(FP32 累积)。设计决策值得参考,尤其是条件量化路径的选择。如果团队在使用 DeepGEMM 运行时,建议尽快合并此 PR 并做回归验证。

讨论亮点
  • 审查者 BBuf 指出 DeepGEMM 仓库的 PR #37 新增了 BF16 contiguous API,但参数名改为 grouped_layout,而 SGLang 的预热代码使用 m_indices= 关键字。作者移除关键字后,BBuf 确认兼容性没问题。
  • gemini-code-assist 建议 _silu_and_mul_kernelup 也应显式转为 tl.float32 再与 gate 相乘,避免 Triton 编译错误。该建议已被采纳,实际代码中 up 已使用 .to(tl.float32) 处理。
  • 无其他争议,BBuf 和 Fridge003 均给予了批准。

实现拆解

  1. 修复专家0遗漏:在python/sglang/srt/layers/moe/ep_moe/kernels.pypost_reorder_triton_kernel函数中,将专家ID过滤条件从expert_id > 0改为expert_id >= 0,确保排序索引为0的专家被正常加权组合。这是导致精度下降的直接原因。
  2. 提升数值精度:在同一个文件的两个Triton kernel中引入FP32累积:_silu_and_mul_kernelgateup都先转为tl.float32计算SiLU,再转回输入dtype存储;post_reorder_triton_kernel的加权和累加器从InDtype改为tl.float32,权重和输入都先转fp32,最后再转回输出dtype。这减少了低精度下的舍入误差。
  3. 支持BF16 A2A:在python/sglang/srt/layers/moe/moe_runner/deep_gemm.pypre_permute_standard_to_deep_gemm函数中,新增output_dtype推导逻辑:若quant_info.w13_weight.dtype为bf16,则输出torch.bfloat16;否则保留FP8。该输出dtype作为参数传入moe_ep_deepgemm_preprocess。在kernels.py的预处理函数中,依据is_fp8(由output_dtype决定)条件性地进行FP8逐token缩放量化,在BF16模式下直接跳过量化步骤。
  4. 修复BF16 grouped contiguous warmup:在python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py_BF16GroupedContWarmupExecutor.execute中,将deep_gemm.m_grouped_bf16_gemm_nt_contiguous调用改为使用位置参数self.m_indices[:m]而非关键字参数m_indices=self.m_indices[:m],以兼容sgl-deep-gemm 0.1.1的API变化。
  5. 升级依赖与镜像:将sgl-deep-gemm版本从0.1.0提升至0.1.1,同步更新python/pyproject.tomldocker/Dockerfile中的版本号。0.1.1提供了BF16 masked grouped GEMM的必要支持。
文件 模块 状态 重要度
python/sglang/srt/layers/moe/ep_moe/kernels.py MoE 内核 modified 7.07
python/sglang/srt/layers/moe/moe_runner/deep_gemm.py DeepGEMM 适配 modified 6.04
python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py 编译预热 modified 4.89
python/pyproject.toml 依赖配置 modified 3.07
docker/Dockerfile 部署脚本 modified 2.38

关键符号

post_reorder_triton_kernel _silu_and_mul_kernel moe_ep_deepgemm_preprocess fill_gateup_input_triton_kernel pre_permute_standard_to_deep_gemm _BF16GroupedContWarmupExecutor.execute

关键源码片段

python/sglang/srt/layers/moe/ep_moe/kernels.py core-logic

核心修复文件:修复专家 0 遗漏、引入 FP32 累积、添加 BF16 条件量化路径。

@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty
    src_idx_int32 = tl.program_id(0)
    src_idx = src_idx_int32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk
    store_ptr = output_ptr + src_idx * hidden_size
    vec = tl.arange(0, BLOCK_SIZE)
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + vec
        mask = offset < hidden_size
        # 使用 FP32 累积以减少精度损失
        sum_vec = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            # 关键修复:使用 >= 0 以包括专家 0(之前遗漏)
            if expert_id >= 0:
                dst_idx_int32 = tl.load(src2dst_ptr + idx)
                dst_idx = dst_idx_int32.to(tl.int64)
                weigh_scale = tl.load(topk_weights_ptr + idx).to(tl.float32)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                # 在 FP32 中累加专家输出以获得更好精度,然后转换为最终输出 dtype
                in_data = tl.load(load_ptr + offset, mask=mask).to(tl.float32)
                sum_vec += in_data * weigh_scale
        tl.store(store_ptr + offset, sum_vec.to(InDtype), mask=mask)

@triton.jit
def _silu_and_mul_kernel(...):
    # ... 前置代码省略
    for token_index in tl.range(token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE):
        gate = tl.load(input_ptr_offs + token_index * stride_input_1, mask=offs_in_d < size_n, other=0.0).to(tl.float32)
        up = tl.load(input_ptr_offs + token_index * stride_input_1 + size_n, mask=offs_in_d < size_n, other=0.0).to(tl.float32)
        gate = gate / (1 + tl.exp(-gate))
        gate_up = up * gate
        # 在 FP32 中计算 SiLU 以提高精度,然后转换回输入 dtype
        gate_up = gate_up.to(input_ptr.dtype.element_ty)
        tl.store(output_ptr_offs + token_index * stride_output_1, gate_up, mask=offs_in_d < size_n)

def moe_ep_deepgemm_preprocess(
    topk_ids,
    num_local_experts,
    hidden_states,
    top_k,
    block_shape,
    output_dtype=torch.float8_e4m3fn, # 新增参数,默认为 FP8
):
    # ... 初始化代码省略
    is_fp8 = output_dtype == torch.float8_e4m3fn
    if is_fp8:
        # FP8 路径:执行缩放及量化
        ...
    else:
        # BF16 路径:跳过 FP8 量化,保持为 BF16
        ...

python/sglang/srt/layers/moe/moe_runner/deep_gemm.py core-logic

添加 output_dtype 推导逻辑,使预处理能够根据权重 dtype 决定输出格式,是 BF16 支持的关键。

@register_pre_permute("standard", "deep_gemm")
def pre_permute_standard_to_deep_gemm(
    dispatch_output: StandardDispatchOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> DeepGemmRunnerInput:
    # ... 前置代码省略
    # 根据权重的 dtype 决定输出 dtype,匹配 runner 的 GEMM 分发逻辑
    output_dtype = (
        torch.bfloat16
        if quant_info.w13_weight.dtype == torch.bfloat16
        else torch.float8_e4m3fn
    )
    masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
        moe_ep_deepgemm_preprocess(
            topk_ids,
            runner_config.num_local_experts,
            hidden_states,
            runner_config.top_k,
            quant_info.block_shape,
            output_dtype=output_dtype, # 传入 output_dtype
        )
    )
    # ... 后续代码不变

评论区精华

BF16 contiguous API 参数名兼容性 other

BBuf 指出 DeepGEMM PR #37 新增 BF16 contiguous API,但参数名改为 grouped_layout,现有预热代码使用 m_indices= 关键字,需要适配。

结论:作者移除关键字参数,BBuf 确认兼容性没问题。 · 已解决

up 张量应显式转为 float32 正确性

gemini-code-assist 机器人建议将 up 也显式转为 float32 以避免 Triton 编译错误。

结论:已被采纳,代码中 up 已使用 .to(tl.float32) 处理。 · 已解决

风险与影响

  • 核心路径变更:修改了 MoE 计算的关键 kernel(post_reorder_triton_kernel、_silu_and_mul_kernel、预处理函数),存在引入数值回归的风险,尤其是在已支持的 FP8 模型上。
  • 缺少单元测试覆盖:PR 没有新增测试用例,修复依赖于手动 GSM8K 验证。如果后期有其他分支修改了这些 kernel,可能难以快速发现回归。
  • 依赖兼容性:升级 sgl-deep-gemm 到 0.1.1 并调整 API 调用,若未来版本再次改变接口可能导致启动失败。
  • 用户影响:使用 --moe-runner-backend deep_gemm 并搭配标准 A2A 后端的用户将能正确运行 BF16 模型,FP8 模型精度在 GSM8K 上从 0.794 提升至 0.813。未使用该参数的用户无影响。
  • 系统影响:数值精度累积方式变更可能使结果与之前版本不一致,但属于预期提升。无显著性能退化(GSM8K 输出吞吐量从 1478 token/s 略升至 1486 token/s)。
  • 团队影响:需要保持对 sgl-deep-gemm 版本的跟踪,确保 API 向后兼容。
核心路径变更 缺少测试覆盖 依赖升级

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论