Prhub

#16859 [RL] DeepEP support for `--enable-return-routed-experts`

原始 PR 作者 PrinsYin 合并时间 2026-05-06 11:01 文件变更 2 提交数 8 评论 5 代码增减 +66 / -28

执行摘要

DeepEP 支持 routed experts 捕获与 all-gather

从 slime 补丁迁移回上游(issue #1316),需要在 DeepEP a2a 模式下正确捕获 routed experts。之前的 tp=2 dp=2 测试配置导致 attn_tp_size=1,gather 路径从未被执行。

值得精读,尤其是 capture_get_local_slice 的设计权衡,以及测试如何构造有效覆盖。对从事分布式 MoE 和 RL 捕获的同学有参考价值。

讨论亮点

PR body 中提及另一种 late-gather 实现(#17892 由 ocss884 提出),在 D2H 同步时才 gather;本 PR 保持 early-gather 方式(capture 时 gather)。关于测试配置,commit 历史显示逐步调整:先改为 tp=4 dp+deepep,然后固定 baseline/reference 仅变 perf 标志,最后因 DeepEP 正常模式要求改为 FP8 模型并强制 --deepep-mode normal 以避免低延迟模式 buffer 不足。

实现拆解

  1. 导入新增与预分配 gather bufferpython/sglang/srt/state_capturer/routed_experts.py):新增 attn_tp_all_gather_into_tensorget_attention_tp_sizeget_moe_a2a_backend 导入;在 __init__ 中,若后端为 DeepEP,则预分配一个 gather_buffer,大小为 device_cache.buffer.shape[0] * attn_tp_size,用于容纳 all-gather 后的完整 topk_ids。
  2. 重写 capture 方法:新增 capture 方法,当 DeepEP 启用时,先保存局部 topk_indices,然后从 gather_buffer 中切片目标区域,调用 attn_tp_all_gather_into_tensor 执行 all-gather,最后调用父类的 capture 将合并后的数据写入设备缓存。
  3. 调整 _get_local_slice 条件:原来在 DP attention 下基于 DP rank 切片,现在对于 DeepEP 模式,capture 已将所有 rank 数据 gather 到 buffer 头部,因此直接使用 [0:N_local] 而不是全局偏移量,故增加 not get_moe_a2a_backend().is_deepep() 条件。
  4. 测试重写test/registered/rl/test_return_routed_experts.py):将 baseline 和 reference 统一为 --tp 4 --dp 2 --enable-dp-attention --moe-a2a-backend deepep --deepep-mode normal,仅切换性能标志(overlap/cuda-graph/radix),确保 gather 路径被真正执行;模型换为 FP8 版本以兼容 DeepEP normal 模式(Bf16 断言过时)。
文件 模块 状态 重要度
python/sglang/srt/state_capturer/routed_experts.py 状态捕获 modified 7.27
test/registered/rl/test_return_routed_experts.py 测试 modified 5.64

关键符号

RoutedExpertsCapturer.create RoutedExpertsCapturer.__init__ RoutedExpertsCapturer.capture RoutedExpertsCapturer._get_local_slice

关键源码片段

python/sglang/srt/state_capturer/routed_experts.py entrypoint

核心实现:新增 DeepEP 路径的 capture all-gather 与 DP 切片逻辑调整

# python/sglang/srt/state_capturer/routed_experts.py # head 版本
​
    def capture(self, layer_id: int, topk_indices: torch.Tensor):
        # 在 DeepEP 模式下,每个 attn-TP rank 只持有 topk_ids 的散列切片,
        # 需要在写入 device cache 之前跨 attn-TP 做 all-gather 恢复完整视图。
        if get_moe_a2a_backend().is_deepep():
            local_topk = topk_indices
            # gather_buffer 预分配的空间足够容纳所有 attn-TP rank 的拼接结果
            topk_indices = self.gather_buffer[
                : local_topk.size(0) * get_attention_tp_size()
            ]
            attn_tp_all_gather_into_tensor(topk_indices, local_topk)
        # 将(可能已 gather 的)topk_indices 写入设备缓存
        super().capture(layer_id, topk_indices)
​
    def _get_local_slice(
        self,
        forward_batch: ForwardBatch,
        can_run_graph: bool,
        cuda_graph_batch: Optional[int],
    ) -> torch.Tensor:
        # 在 DeepEP 路径下,capture() 已经将全局数据 gather 到 buffer 起始位置,
        # 每个 DP rank 的数据位于 [0:N_local] 而非全局偏移 [start_pos:end_pos]。
        # 因此仅在非 DeepEP 的 DP attention 场景才需要做 DP-rank 感知切片。
        if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep():
            local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
            if can_run_graph:
                local_start_pos = get_attention_dp_rank() * cuda_graph_batch
            local_end_pos = local_start_pos + local_num_tokens
        else:
            local_start_pos, local_end_pos = 0, forward_batch.out_cache_loc.shape[0]
        return self.device_cache.buffer[
            local_start_pos:local_end_pos, :, : self.topk_size
        ]

评论区精华

Early-gather vs late-gather for DeepEP all-gather 设计

PR body 提到 ocss884 在 #17892 中实现了 late-gather(在 D2H 同步时 gather),而本 PR 采用 early-gather(在 capture 时 gather)。early-gather 保持现有 _get_local_slice / D2H 路径不变,但增加了预分配 gather_buffer 的显存开销。

结论:本 PR 采用 early-gather 方案,已合并。 · 已解决

风险与影响

仅测试了 DeepEP 后端,其他 MoE a2a 后端(如 libuv)上该代码路径不会触发,但未测试回归。gather_buffer 预分配会占用额外显存,显存开销随 attn_tp_size 线性增加。测试仅覆盖 H100 4-GPU 环境,AMD 和低端 GPU 未验证。另外,_get_local_slice 修改后,当 is_dp_attention_enabled() and not is_deepep() 时行为不变,但条件变化可能影响未来新后端引入时的正确性。

直接影响使用 --enable-return-routed-experts 且配合 DeepEP a2a 后端的用户,现在能正确获取 routed experts 信息。不影响非 DeepEP 用户。测试覆盖增加,但需要 4 GPU H100 资源,CI 运行时间 400 秒。团队需维护 early-gather 实现,并与可能的 late-gather 方案保持一致性。

仅覆盖 DeepEP 后端 显存开销增加 测试限于 FP8 模型 AMD CI 被禁用

关联 Issue

#1316 Mitigate content from sglang.patch to sglang

完整报告

参与讨论