Prhub

#25189 [perf] DeepSeekV3: drop redundant FP32 upcasts in trtllm MoE paths

原始 PR 作者 maocheng23 合并时间 2026-05-23 05:23 文件变更 2 提交数 5 评论 10 代码增减 +3 / -14

执行摘要

移除 DeepSeekV3 trtllm MoE 冗余 FP32 转换

flashinfer 从 0.6.8 开始,trtllm_fp4_block_scale_moe、trtllm_fp8_block_scale_moe 和 trtllm_fp8_per_tensor_scale_moe 直接接受 BF16 router_logits,原代码中为 DeepSeekV3 路由做的显式 FP32 upcast 成为冗余操作。移除这些转换可以减少显存带宽消耗和计算开销,提升推理性能。PR body 还详细解释了为何不一次性修改 gate 输出 dtype:因为 gate 输出被所有 MoE runner 后端共享,每个后端影响不同,需要单独验证。

建议阅读 PR body 中的设计分析,了解作者为何分阶段清理 MoE 路径。该 PR 展示了如何基于版本依赖安全移除冗余操作,并为后续重构留下明确标记,值得 ML 工程团队借鉴。合并者 Qiaolin-Yu 和 reviewer b8zhong 均已批准。

讨论亮点

Reviewer ch-wan 在 deepseek_v2.py 上发表了唯一实质评论:"It's risky. Add a TODO item here for future discussion." 指出直接改动 gate dtype 存在风险,建议添加 TODO。作者采纳意见,在文件中添加了 TODO 注释。其他评论均为 CI rerun 命令,未涉及实质性技术讨论。

实现拆解

  1. 确认 flashinfer 版本条件:当前依赖锁定的 flashinfer 为 0.6.11.post1,满足 ≥0.6.8,移除 FP32 upcast 安全。
  2. 清理 FP8 block-scale 路径:在 fused_experts_none_to_flashinfer_trtllm_fp8 函数中,trtllm_fp8_block_scale_moe_wrapper 调用处删除根据 routing_method_type 条件转换为 FP32 的逻辑,直接传入原始的 router_logits。
  3. 简化 FP8 per-tensor 路径:该路径原先根据是否为 DeepSeekV3 分别转为 FP32 或 BF16;现将条件分支替换为无条件转为 BF16。
  4. 移除 FP4 路径中的显式转换:在 fused_experts_none_to_flashinfer_trtllm_fp4 函数中,删除 if routing_method_type == RoutingMethodType.DeepSeekV3 代码块。
  5. 添加 TODO 标记演进方向:在 deepseek_v2.py 的 MoEGate.router_gemm_forward 方法中,为 SM100/103 + 256-expert 快速路径添加 "TODO: will check the dtype to be bf16",标明后续应将 gate 输出改为 BF16。
文件 模块 状态 重要度
python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py MoE 层 modified 6.64
python/sglang/srt/models/deepseek_v2.py 模型定义 modified 4.39

关键符号

fused_experts_none_to_flashinfer_trtllm_fp8 fused_experts_none_to_flashinfer_trtllm_fp4 MoEGate.router_gemm_forward

关键源码片段

python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py core-logic

核心变更文件,移除三处显式 FP32 upcast,是性能收益的直接来源。

# flashinfer_trtllm.py – fused_experts_none_to_flashinfer_trtllm_fp8 FP8 per-tensor 路径
# flashinfer >= 0.6.8 已原生支持 BF16 router_logits,移除冗余条件转换a_q, _ = scaled_fp8_quant(hidden_states, quant_info.w13_input_scale)
routing_bias_cast = (
    None if correction_bias is None else correction_bias.to(torch.bfloat16)
)# 在对称内存上下文中分配输出
with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()):
    symm_output = torch.empty(
        hidden_states.shape[0],
        hidden_states.shape[1],
        dtype=torch.bfloat16,
        device=hidden_states.device,
    )router_logits = router_logits.to(torch.bfloat16)output = trtllm_fp8_per_tensor_scale_moe_wrapper(
    routing_logits=router_logits,
    routing_bias=routing_bias_cast,
    hidden_states=a_q,
    gemm1_weights=quant_info.w13_weight,
    output1_scales_scalar=quant_info.output1_scales_scalar,
    output1_scales_gate_scalar=quant_info.output1_scales_gate_scalar,
    gemm2_weights=quant_info.w2_weight,
    output2_scales_scalar=quant_info.output2_scales_scalar,
    num_experts=quant_info.global_num_experts,
    top_k=topk_config.top_k,
    n_group=topk_config.num_expert_group,
    topk_group=topk_config.topk_group,
    intermediate_size=int(quant_info.w2_weight.shape[2]),
    local_expert_offset=quant_info.local_expert_offset,
    local_num_experts=quant_info.local_num_experts,
    routed_scaling_factor=(
        runner_config.routed_scaling_factor
        if runner_config.routed_scaling_factor is not None
        else 1.0
    ),
    routing_method_type=routing_method_type,
    use_shuffled_weight=use_shuffled_weight,
    tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
    fp8_quantization_type=int(fp8_quantization_type),
    activation_type=quant_info.activation_type,
)
python/sglang/srt/models/deepseek_v2.py data-contract

添加 TODO 注释标记后续 gate dtype 迁移方向,是分阶段策略的一部分。

# deepseek_v2.py – MoEGate.router_gemm forward(部分)
# SM100/103 256-expert 快速路径当前仍输出 FP32,添加 TODO 为后续 BF16 迁移做标记if _device_sm in [100, 103] and self.weight.shape[0] == 256:
    # TODO: will check the dtype to be bf16
    # router gemm output float32
    logits = torch.empty(
        hidden_states.shape[0],
        self.weight.shape[0],
        device=hidden_states.device,
        dtype=torch.float32,
    )
    flashinfer_dsv3_router_gemm(logits, hidden_states, self.weight)
# ...

评论区精华

风险提示与 TODO 添加 设计

ch-wan 评论指出直接改动 gate dtype 风险较高,建议添加 TODO。作者采纳并添加了 TODO 注释。

结论:保留 gate 输出为 FP32,将 BF16 迁移延后;在 deepseek_v2.py 中添加 TODO 追踪。 · 已解决

风险与影响

主要风险:仅修改 trtllm 路径而 gate 仍输出 FP32,若未来统一为 BF16 时可能引入精度或兼容性问题。但当前改动在 flashinfer 已确认支持 BF16 的前提下进行,且未改变其他后端行为,回归风险较低。未新增测试用例,依赖现有 DeepSeek-V3 推理测试和 FP4/FP8 MoE 准确性测试覆盖。

影响范围仅限于使用 flashinfer trtllm 相关 MoE 路径的 DeepSeek-V3/V3.2 模型推理。用户无需配置变更即可获得性能提升(减少显存带宽和 cast 开销)。对其他模型和 MoE 后端无影响。gate 输出仍为 FP32,因此精度不变。

核心路径变更 分阶段迁移风险 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论