Prhub

#22669 feat: Support flashinfer_cutedsl MoE runner with flashinfer alltoall backend

原始 PR 作者 samuellees 合并时间 2026-05-20 15:36 文件变更 6 提交数 19 评论 22 代码增减 +262 / -52

执行摘要

CuteDSL FP4 MoE 集成 FlashInfer alltoall

在 B200×4 上部署超大 MoE 模型(如 Qwen3.5-397B NVFP4)时,需要 DP attention + EP 配置。原有的 CuteDSL MoE 后端不支持 FlashInfer alltoall,而 deepep 在部分场景不够灵活。本 PR 启用这一组合,并重点解决 idle DP rank(无本地 token)导致的 FP4 GEMM 崩溃和 NCCL 通信不匹配。

如果正在部署 Qwen3.5-397B-A17B-NVFP4 或其他 Blackwell FP4 大 MoE 模型且使用 DP+EP,建议仔细阅读此 PR 以理解 idle rank 处理和 alltoall 集成细节。关注重点包括强制禁用 NVFP4 dispatch 的权衡、sanitize kernel 的必要性、以及 buffer 大小配置的内存影响。目前 PR 已合并,后续可跟踪 NVFP4 dispatch 的支持进展。

讨论亮点
  • 变量命名:YAMY1234 建议将 has_dummy_token 改为 is_idle_rank 以提升可读性,已采纳。
  • NVFP4 dispatch:YAMY1234 和 leejnau 建议强制禁用 NVFP4 dispatch(因未验证),trevor-m 认为应支持并推送了 commit,但最终 PR 仍以强制禁用并配置警告告终。
  • sanitize kernel 性能:trevor-m 质疑 invalid_token_expert_id 带来的 moe_a2a_sanitize_expert_ids 开销,bobboli 解释 padding 槽必须有无效 expert_id 以避免浪费计算,结论是结构上必需。
  • 默认 max_num_tokens 内存风险:trevor-m 指出默认 8192 在 DEP32 下可能导致 225GB buffer,作者在 server_args.py 中添加了容量校验逻辑。
  • CI 阶段调整:ch-wan 建议将新测试归入 extra-b 阶段以节省 CI 资源,已修改。

实现拆解

  1. 配置白名单与断言server_args.py):在 cutedsl 的 moe_a2a_backend 白名单加入 "flashinfer",在 flashinfer a2a 的 moe_runner_backend 白名单加入 "flashinfer_cutedsl"。强制禁用 SGLANG_MOE_NVFP4_DISPATCH(因 scale factor 布局不兼容)。验证 dp_size == tp_size 并调整 ep_size。增加容量检查确保 SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK * ep_size >= max_prefill_tokens
  2. 模型前向适配qwen2_moe.py):在 forward 中添加 M=0 守卫分支——当 hidden_states 为空时跳过 shared_expert 和 gate,但仍调用 self.experts() 参与 alltoall。在 a2a 模式下跳过最终的 TP allreduce(因为 alltoall 已聚合 routed expert 结果)。shared_expert 初始化时传入 tp_size=1(与 deepep 一致)。
  3. Token 分发器重构flashinfer.py):移除预分配的 dummy tensor,直接传递空张量(kernel 原生支持)。重命名 has_dummy_tokenis_idle_rank。在 dispatch 中设置 invalid_token_expert_id=self.num_experts 让 MoE 跳过 padding 槽位。默认 max_num_tokens 从 1024 提升至 8192。
  4. MoE 执行器注册与 buffer 适配flashinfer_cutedsl.py):新增 fused_experts_flashinfer_to_flashinfer_cutedsl_fp4 处理 NVFP4 和 bf16 两种输入。ensure_cutedsl_wrapper 增加 num_tokens 参数,动态计算 wrapper buffer 大小(a2a 路径根据 dispatcher 限制,allgather 路径根据首次 token 数)。
  5. 测试与 CI(新增 test_flashinfer_a2a_cutedsl_v2.py):E2E 测试在 B200×4 上启动 Qwen3.5-397B-NVFP4(TP4 EP4 DP4),运行 GSM8K 并断言准确率 >90%,归入 CI extra-b 阶段(约 600s)。
文件 模块 状态 重要度
python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py MoE 执行器 modified 8.06
test/registered/moe/test_flashinfer_a2a_cutedsl_v2.py 测试 added 7.12
python/sglang/srt/server_args.py 配置校验 modified 6.58
python/sglang/srt/models/qwen2_moe.py 模型适配 modified 6.83
python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py Token 分发器 modified 6.7

关键符号

ensure_cutedsl_wrapper fused_experts_flashinfer_to_flashinfer_cutedsl_fp4 FlashinferDispatcher.dispatch Qwen2MoeMoE.forward

关键源码片段

python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py core-logic

核心 MoE 执行器文件,新增 fused func 注册和 buffer size 适配,支持 flashinfer 分发器

def ensure_cutedsl_wrapper(layer: torch.nn.Module, num_tokens: int = 0) -> None:
    """Lazily create CuteDslMoEWrapper and resolve scales on first forward.    Args:
        layer: The FusedMoE layer module.
        num_tokens: Current token count entering the MoE layer. Used as
            the buffer size for the non-a2a (allgather) path, where the
            autotune dummy run passes req_to_token_pool.size * dp_size —
            the worst-case post-allgather batch. For the a2a path this
            is ignored in favour of the dispatcher's workspace limit.
    """
    if getattr(layer, "_cutedsl_wrapper", None) is not None:
        return
​
    try:
        from flashinfer import CuteDslMoEWrapper
    except ImportError as e:
        raise ImportError(
            "flashinfer_cutedsl backend requires FlashInfer with CuteDSL support. "
            "Install with: pip install flashinfer"
        ) from e
​
    from sglang.srt.server_args import get_global_server_args
​
    assert layer.intermediate_size_per_partition > 0, (
        f"CuteDSL MoE: intermediate_size_per_partition must be > 0, "
        f"got {layer.intermediate_size_per_partition}. Check EP/TP configuration."
    )
​
    server_args = get_global_server_args()
    use_cuda_graph = server_args is not None and not server_args.disable_cuda_graph
​
    # Buffer size must cover the worst-case token count the MoE layer can see.
    # - A2A path: dispatch returns tensors flattened from
    # [ep_size, max_tokens_per_rank, ...].
    # - Standard allgather path: dp_size * max local tokens per rank.
    dispatcher = getattr(layer, "dispatcher", None)
    if hasattr(dispatcher, "max_num_tokens"):
        # A2A 路径:使用 dispatcher 限制的每 rank 最大 tokens 乘 ep_size
        # 这样 buffer 大小可由用户通过环境变量 SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK 调节
        max_num_tokens = dispatcher.max_num_tokens * getattr(dispatcher, "ep_size", 1)
    else:
        # 标准 allgather 路径:使用第一次 forward 传入的 num_tokens
        # num_tokens 是 req_to_token_pool.size * dp_size,即最坏情况下的 post-allgather 批量大小
        max_num_tokens = max(num_tokens, 1)
​
    top_k = layer.top_k if layer.top_k is not None else layer.moe_runner_config.top_k
    # ... 后续在 inference_mode 外创建 CuteDslMoEWrapper ...
python/sglang/srt/models/qwen2_moe.py data-contract

模型前向适配:M=0 guard、TP allreduce 跳过、shared_expert tp_size=1 支持 flashinfer a2a

def forward(self, hidden_states, forward_batch=None, use_reduce_scatter=False, should_allreduce_fusion=False):
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
​
    if get_moe_a2a_backend().is_deepep():
        return self._forward_deepep(hidden_states, forward_batch)
​
    # M=0 guard: idle DP rank 的 hidden_states 为空时,
    # 跳过 shared_expert 和 gate (FP4 GEMM 无法处理空张量 ),
    # 但必须调用 self.experts() 以参与 alltoall 集合通信。
    if hidden_states.shape[0] == 0:
        shared_output = None
        topk_output = self.topk.empty_topk_output(hidden_states.device)
        final_hidden_states = self.experts(hidden_states, topk_output)
    elif self.alt_stream is not None and get_is_capture_mode():
        final_hidden_states, shared_output = self.forward_normal_dual_stream(hidden_states)
    else:
        shared_output = self._forward_shared_experts(hidden_states)
        final_hidden_states = self._forward_router_experts(hidden_states)
​
    if shared_output is not None:
        final_hidden_states += shared_output
​
    # FlashInfer alltoall 已经聚合了 routed expert 结果,
    # 因此跳过此步的 TP allreduce,否则 idle rank 会导致 NCCL 大小不匹配。
    if (self.tp_size > 1
            and not should_skip_post_experts_all_reduce(
                is_tp_path=True,
                use_reduce_scatter=use_reduce_scatter,
                should_allreduce_fusion=should_allreduce_fusion,
            )
            and not get_moe_a2a_backend().is_flashinfer()):
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
​
    return final_hidden_states.view(num_tokens, hidden_dim)
python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py core-logic

Token 分发器重构,移除 dummy token 机制,添加 invalid_token_expert_id,重命名 is_idle_rank

def dispatch(self, x, topk_output):
    topk_ids = topk_output.topk_ids
    topk_weights = topk_output.topk_weights
​
    # 检测 idle rank: 本地无 token 时,使用空张量(kernel 原生支持)
    self.is_idle_rank = x.shape[0] == 0
    if self.is_idle_rank:
        x = hidden_states.new_zeros((1, self.hidden_size))
        topk_ids = self.dummy_topk_ids
        topk_weights = self.dummy_topk_weights
​
    # 若启用 NVFP4 dispatch,量化输入
    global_scale = self.quant_config.get("input_global_scale", None)
    if global_scale is not None:
        x, x_sf = fp4_quantize(x, global_scale, is_sf_swizzled_layout=False)
​
    # ... 构建 payloads ...
​
    if self.is_idle_rank:
        self.runtime_max_tokens_per_rank = max(self.runtime_max_tokens_per_rank, 1)
​
    # 调用 alltoall 时传入 invalid_token_expert_id = num_experts,
    # 这样 MoE kernel 会跳过 padding 槽位,避免浪费计算。
    recv_tensors = self.moe_a2a.dispatch(
        topk_ids,
        payloads,
        self.runtime_max_tokens_per_rank,
        invalid_token_expert_id=self.num_experts, # 标记 padding 为无效
        expert_id_payload_index=expert_id_payload_index,
    )

评论区精华

变量命名:has_dummy_token -> is_idle_rank style

YAMY1234 建议将 has_dummy_token 重命名为 is_idle_rank 以提升可读性。

结论:作者采纳并修改。 · 已解决

NVFP4 dispatch 与 CuteDSL scale factor 兼容性 设计

YAMY1234 询问 NVFP4 dispatch 路径是否验证,建议强制禁用;leejnau 建议在 server_args 中显式设置 env;trevor-m 认为应支持并推送 commit。

结论:因 scale factor 布局不兼容,暂强制禁用 NVFP4 dispatch,发出 warning。 · 已解决

invalid_token_expert_id 引起的 sanitize kernel 性能开销 性能

trevor-m 质疑 sanitize kernel 是否必要;bobboli 解释 padding 槽必须有无效 expert_id 否则浪费计算。

结论:保留 invalid_token_expert_id,结构上必需。 · 已解决

默认 max_num_tokens 提高导致内存膨胀 性能

trevor-m 指出默认 8192 在 DEP32 下 buffer 达 225GB。

结论:在 server_args 中增加容量校验,确保 SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK * ep_size >= max_prefill_tokens。 · 已解决

CI 测试阶段调整 other

ch-wan 建议将测试移到 extra-b 阶段以节约 CI 资源。

结论:已修改测试注册中的 stage 为 extra-b。 · 已解决

风险与影响

  • NVFP4 dispatch 兼容性:由于 scale factor 布局不兼容,PR 强制禁用了 NVFP4 dispatch(SGLANG_MOE_NVFP4_DISPATCH=0),导致 prefill 阶段输入保持 BFloat16,可能影响 prefill 性能。后续需验证并启用。
  • 内存膨胀风险max_num_tokens 默认 8192,在大 EP(如 32)时预分配 buffer 可能超出显存。server_args 添加了容量检查,但用户需注意环境变量 SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK 的适当设置。
  • sanitize kernel 性能invalid_token_expert_id 触发 moe_a2a_sanitize_expert_ids kernel,带来微小延迟,但结构上必须。
  • 仅 FP4 量化:CuteDSL runner 仅支持 modelopt_fp4,其他量化类型会被断言拒绝。
  • 与其他后端的互斥:启用 flashinfer a2a 时强制 shared_expert tp_size=1 并禁用 TP allreduce,与默认行为不同。
  • 用户:主要影响使用 NVFP4 大 MoE 模型(如 Qwen3.5-397B)并启用 DP attention + EP 的用户,提供新的通信后端选择,可能改善扩展性和吞吐。现有配置向后兼容。
  • 系统:增加 --moe-a2a-backend flashinfer--moe-runner-backend flashinfer_cutedsl 的组合。在 B200×4 测试中准确率持平,吞吐提升 ~1%。其他 GPU 和模型未验证。
  • 团队:需维护新的 fused func 和 dispatcher 逻辑。CI 新增 600s 的 E2E 测试。未来可考虑支持 NVFP4 dispatch。
强制禁用 NVFP4 dispatch 路径 大 DP 配置下内存压力 sanitize kernel 性能开销

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论