执行摘要
- 一句话:CuteDSL FP4 MoE 集成 FlashInfer alltoall
- 推荐动作:如果正在部署 Qwen3.5-397B-A17B-NVFP4 或其他 Blackwell FP4 大 MoE 模型且使用 DP+EP,建议仔细阅读此 PR 以理解 idle rank 处理和 alltoall 集成细节。关注重点包括强制禁用 NVFP4 dispatch 的权衡、sanitize kernel 的必要性、以及 buffer 大小配置的内存影响。目前 PR 已合并,后续可跟踪 NVFP4 dispatch 的支持进展。
功能与动机
在 B200×4 上部署超大 MoE 模型(如 Qwen3.5-397B NVFP4)时,需要 DP attention + EP 配置。原有的 CuteDSL MoE 后端不支持 FlashInfer alltoall,而 deepep 在部分场景不够灵活。本 PR 启用这一组合,并重点解决 idle DP rank(无本地 token)导致的 FP4 GEMM 崩溃和 NCCL 通信不匹配。
实现拆解
- 配置白名单与断言(
server_args.py):在 cutedsl 的 moe_a2a_backend 白名单加入 "flashinfer",在 flashinfer a2a 的 moe_runner_backend 白名单加入 "flashinfer_cutedsl"。强制禁用 SGLANG_MOE_NVFP4_DISPATCH(因 scale factor 布局不兼容)。验证 dp_size == tp_size 并调整 ep_size。增加容量检查确保 SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK * ep_size >= max_prefill_tokens。
- 模型前向适配(
qwen2_moe.py):在 forward 中添加 M=0 守卫分支——当 hidden_states 为空时跳过 shared_expert 和 gate,但仍调用 self.experts() 参与 alltoall。在 a2a 模式下跳过最终的 TP allreduce(因为 alltoall 已聚合 routed expert 结果)。shared_expert 初始化时传入 tp_size=1(与 deepep 一致)。
- Token 分发器重构(
flashinfer.py):移除预分配的 dummy tensor,直接传递空张量(kernel 原生支持)。重命名 has_dummy_token 为 is_idle_rank。在 dispatch 中设置 invalid_token_expert_id=self.num_experts 让 MoE 跳过 padding 槽位。默认 max_num_tokens 从 1024 提升至 8192。
- MoE 执行器注册与 buffer 适配(
flashinfer_cutedsl.py):新增 fused_experts_flashinfer_to_flashinfer_cutedsl_fp4 处理 NVFP4 和 bf16 两种输入。ensure_cutedsl_wrapper 增加 num_tokens 参数,动态计算 wrapper buffer 大小(a2a 路径根据 dispatcher 限制,allgather 路径根据首次 token 数)。
- 测试与 CI(新增
test_flashinfer_a2a_cutedsl_v2.py):E2E 测试在 B200×4 上启动 Qwen3.5-397B-NVFP4(TP4 EP4 DP4),运行 GSM8K 并断言准确率 >90%,归入 CI extra-b 阶段(约 600s)。
关键文件:
python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py(模块 MoE 执行器;类别 source;类型 core-logic;符号 ensure_cutedsl_wrapper, fused_experts_flashinfer_to_flashinfer_cutedsl_fp4): 核心 MoE 执行器文件,新增 fused func 注册和 buffer size 适配,支持 flashinfer 分发器
test/registered/moe/test_flashinfer_a2a_cutedsl_v2.py(模块 测试;类别 test;类型 test-coverage;符号 TestCuteDslFlashinferA2A, setUpClass, tearDownClass, test_gsm8k): 新增 E2E 集成测试,在 B200×4 上验证 cutedsl + flashinfer a2a 组合的准确率
python/sglang/srt/server_args.py(模块 配置校验;类别 source;类型 configuration): 配置校验和兼容性断言,添加 flashinfer a2a 白名单,强制禁用 NVFP4 dispatch
python/sglang/srt/models/qwen2_moe.py(模块 模型适配;类别 source;类型 data-contract): 模型前向适配:M=0 guard、TP allreduce 跳过、shared_expert tp_size=1 支持 flashinfer a2a
python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py(模块 Token 分发器;类别 source;类型 core-logic): Token 分发器重构,移除 dummy token 机制,添加 invalid_token_expert_id,重命名 is_idle_rank
关键符号:ensure_cutedsl_wrapper, fused_experts_flashinfer_to_flashinfer_cutedsl_fp4, FlashinferDispatcher.dispatch, Qwen2MoeMoE.forward
关键源码片段
python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py
核心 MoE 执行器文件,新增 fused func 注册和 buffer size 适配,支持 flashinfer 分发器
def ensure_cutedsl_wrapper(layer: torch.nn.Module, num_tokens: int = 0) -> None:
"""Lazily create CuteDslMoEWrapper and resolve scales on first forward.
Args:
layer: The FusedMoE layer module.
num_tokens: Current token count entering the MoE layer. Used as
the buffer size for the non-a2a (allgather) path, where the
autotune dummy run passes req_to_token_pool.size * dp_size —
the worst-case post-allgather batch. For the a2a path this
is ignored in favour of the dispatcher's workspace limit.
"""
if getattr(layer, "_cutedsl_wrapper", None) is not None:
return
try:
from flashinfer import CuteDslMoEWrapper
except ImportError as e:
raise ImportError(
"flashinfer_cutedsl backend requires FlashInfer with CuteDSL support. "
"Install with: pip install flashinfer"
) from e
from sglang.srt.server_args import get_global_server_args
assert layer.intermediate_size_per_partition > 0, (
f"CuteDSL MoE: intermediate_size_per_partition must be > 0, "
f"got {layer.intermediate_size_per_partition}. Check EP/TP configuration."
)
server_args = get_global_server_args()
use_cuda_graph = server_args is not None and not server_args.disable_cuda_graph
# Buffer size must cover the worst-case token count the MoE layer can see.
# - A2A path: dispatch returns tensors flattened from
# [ep_size, max_tokens_per_rank, ...].
# - Standard allgather path: dp_size * max local tokens per rank.
dispatcher = getattr(layer, "dispatcher", None)
if hasattr(dispatcher, "max_num_tokens"):
# A2A 路径:使用 dispatcher 限制的每 rank 最大 tokens 乘 ep_size
# 这样 buffer 大小可由用户通过环境变量 SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK 调节
max_num_tokens = dispatcher.max_num_tokens * getattr(dispatcher, "ep_size", 1)
else:
# 标准 allgather 路径:使用第一次 forward 传入的 num_tokens
# num_tokens 是 req_to_token_pool.size * dp_size,即最坏情况下的 post-allgather 批量大小
max_num_tokens = max(num_tokens, 1)
top_k = layer.top_k if layer.top_k is not None else layer.moe_runner_config.top_k
# ... 后续在 inference_mode 外创建 CuteDslMoEWrapper ...
python/sglang/srt/models/qwen2_moe.py
模型前向适配:M=0 guard、TP allreduce 跳过、shared_expert tp_size=1 支持 flashinfer a2a
def forward(self, hidden_states, forward_batch=None, use_reduce_scatter=False, should_allreduce_fusion=False):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if get_moe_a2a_backend().is_deepep():
return self._forward_deepep(hidden_states, forward_batch)
# M=0 guard: idle DP rank 的 hidden_states 为空时,
# 跳过 shared_expert 和 gate (FP4 GEMM 无法处理空张量 ),
# 但必须调用 self.experts() 以参与 alltoall 集合通信。
if hidden_states.shape[0] == 0:
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
final_hidden_states = self.experts(hidden_states, topk_output)
elif self.alt_stream is not None and get_is_capture_mode():
final_hidden_states, shared_output = self.forward_normal_dual_stream(hidden_states)
else:
shared_output = self._forward_shared_experts(hidden_states)
final_hidden_states = self._forward_router_experts(hidden_states)
if shared_output is not None:
final_hidden_states += shared_output
# FlashInfer alltoall 已经聚合了 routed expert 结果,
# 因此跳过此步的 TP allreduce,否则 idle rank 会导致 NCCL 大小不匹配。
if (self.tp_size > 1
and not should_skip_post_experts_all_reduce(
is_tp_path=True,
use_reduce_scatter=use_reduce_scatter,
should_allreduce_fusion=should_allreduce_fusion,
)
and not get_moe_a2a_backend().is_flashinfer()):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py
Token 分发器重构,移除 dummy token 机制,添加 invalid_token_expert_id,重命名 is_idle_rank
def dispatch(self, x, topk_output):
topk_ids = topk_output.topk_ids
topk_weights = topk_output.topk_weights
# 检测 idle rank: 本地无 token 时,使用空张量(kernel 原生支持)
self.is_idle_rank = x.shape[0] == 0
if self.is_idle_rank:
x = hidden_states.new_zeros((1, self.hidden_size))
topk_ids = self.dummy_topk_ids
topk_weights = self.dummy_topk_weights
# 若启用 NVFP4 dispatch,量化输入
global_scale = self.quant_config.get("input_global_scale", None)
if global_scale is not None:
x, x_sf = fp4_quantize(x, global_scale, is_sf_swizzled_layout=False)
# ... 构建 payloads ...
if self.is_idle_rank:
self.runtime_max_tokens_per_rank = max(self.runtime_max_tokens_per_rank, 1)
# 调用 alltoall 时传入 invalid_token_expert_id = num_experts,
# 这样 MoE kernel 会跳过 padding 槽位,避免浪费计算。
recv_tensors = self.moe_a2a.dispatch(
topk_ids,
payloads,
self.runtime_max_tokens_per_rank,
invalid_token_expert_id=self.num_experts, # 标记 padding 为无效
expert_id_payload_index=expert_id_payload_index,
)
评论区精华
风险与影响
- 风险:
- NVFP4 dispatch 兼容性:由于 scale factor 布局不兼容,PR 强制禁用了 NVFP4 dispatch(
SGLANG_MOE_NVFP4_DISPATCH=0),导致 prefill 阶段输入保持 BFloat16,可能影响 prefill 性能。后续需验证并启用。
- 内存膨胀风险:
max_num_tokens 默认 8192,在大 EP(如 32)时预分配 buffer 可能超出显存。server_args 添加了容量检查,但用户需注意环境变量 SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK 的适当设置。
- sanitize kernel 性能:
invalid_token_expert_id 触发 moe_a2a_sanitize_expert_ids kernel,带来微小延迟,但结构上必须。
- 仅 FP4 量化:CuteDSL runner 仅支持 modelopt_fp4,其他量化类型会被断言拒绝。
- 与其他后端的互斥:启用 flashinfer a2a 时强制 shared_expert tp_size=1 并禁用 TP allreduce,与默认行为不同。
- 影响:
- 用户:主要影响使用 NVFP4 大 MoE 模型(如 Qwen3.5-397B)并启用 DP attention + EP 的用户,提供新的通信后端选择,可能改善扩展性和吞吐。现有配置向后兼容。
- 系统:增加
--moe-a2a-backend flashinfer 与 --moe-runner-backend flashinfer_cutedsl 的组合。在 B200×4 测试中准确率持平,吞吐提升 ~1%。其他 GPU 和模型未验证。
- 团队:需维护新的 fused func 和 dispatcher 逻辑。CI 新增 600s 的 E2E 测试。未来可考虑支持 NVFP4 dispatch。
- 风险标记:强制禁用 NVFP4 dispatch 路径, 大 DP 配置下内存压力, sanitize kernel 性能开销
关联脉络
- PR #25825 [Refactor] Pass PP start_layer via model constructor instead of forward_batch.token_to_kv_pool: 同样修改了 qwen2_moe.py,调整 MoE 层构造方式;本 PR 在此基础上增加了 flashinfer a2a 的 shared_expert tp_size=1 逻辑,需确保合并无冲突。
参与讨论