Prhub

#23269 Support batch size > 1 when enable CP

原始 PR 作者 Shunkangz 合并时间 2026-05-28 05:11 文件变更 13 提交数 13 评论 51 代码增减 +268 / -305

执行摘要

上下文并行支持 batch size > 1

PR body 中明确目标:"Enable batch size > 1 with context parallel." 此前 CP 的 zigzag 切分和元数据仅支持单请求,严重限制了吞吐。Review 中 kpham-sgl 指出当前 MHA CP 实现中存在两处已知 bug,并强调需要 attn_cp_size == 4 且 bs > 1 的测试覆盖。

值得精读,尤其 ContextParallelMetadata 从单序列到多序列的设计演进,以及 padding 策略的权衡。讨论中的 CPU 开销担忧和未来 Triton 替代方向值得关注。架构师应关注 DSA 路径的遗留 TODO。

讨论亮点
  • can_cp_split 异常 vs graceful:kpham-sgl 担心抛异常会导致生产崩溃,建议回退;Shunkangz 最初认为应显式暴露,但最终采用逐请求检查并 return False 的 graceful 方案。
  • padding 到 2×cp_size:kpham-sgl 提议将 padding 从 cp_size 改为 2×cp_size 以支持 zigzag 负载均衡,被采纳并实现在 forward_batch_info.py
  • cu_seqlens 移至 prepare_context_parallel_metadata:Shunkangz 解释是为了减少 kernel launch 前的 CPU 索引拷贝开销,促进 kernel 尽早发射。
  • CPU 元数据开销担忧:Fridge003 建议未来用 Triton kernel 替代 CPU 循环,kpham-sgl 表示 +1,目前未实施但有 TODO。
  • DSA indexer 仍假设 batch=1:kpham-sgl 指出 dsa_indexer 中 *_list[0] 只适用于单请求,Shunkangz 确认保持 TODO 以便后续多 batch 支持。

实现拆解

  1. ContextParallelMetadata 数据结构重构cp_utils.py):将 kv_len_prevkv_len_next 等标量替换为形状 [bs] 的 CUDA tensor(如 kv_len_prev_tensor),新增 cu_seqlens_q_prev_tensor 辅助 FlashAttention,新增 bs 字段记录批大小,total_seq_lens 改为 int。
  2. CP 切分判断升级cp_utils.pycan_cp_split):移除 batch_size == 1 限制,改为逐个检查 extend_seq_lens_cpu 是否满足 len >= cp_size * 2,不满足时返回 False 而非抛异常。
  3. Padding 对齐粒度调整forward_batch_info.py):对齐基数从 attn_cp_size 改为 attn_cp_size * 2,确保 zigzag 切分负载均衡。
  4. 调度策略松绑schedule_policy.py):移除 self.prefill_context_parallel_enabledadd_one_reqcan_run_list 长度的限制,允许多请求进入。
  5. 注意力后端适配:在 cp_attn_forward_extend 中按 total_q_prev_tokens 切分 q,传入 cu_seqlens_q_prev/next_tensor;在 dsa_indexer.py 中将标量字段替换为 kv_len_prev_list[0] 等列表元素(仍假设 batch=1,预留 TODO)。
  6. 调用参数统一:所有模型入口(deepseek_v2.pydeepseek_nextn.pydeepseek_v4.pyqwen3_moe.py 等)将 prepare_context_parallel_metadataextend_lens= 更名为 extend_seqs_len=
  7. 测试配套:删除 test_qwen3_30b.py(原 CP 精度测试),修改 test_mla_cp_fa3_parity.py 适配新字段。
文件 模块 状态 重要度
python/sglang/srt/layers/utils/cp_utils.py CP 核心 modified 7.97
python/sglang/srt/layers/attention/dsa/dsa_indexer.py DSA 索引器 modified 5.8
python/sglang/srt/models/deepseek_v2.py DeepSeek 模型 modified 5.28
python/sglang/srt/models/deepseek_nextn.py NextN 模型 modified 5.28
python/sglang/srt/managers/schedule_policy.py 调度器 modified 5.47
python/sglang/srt/model_executor/forward_batch_info.py 批处理信息 modified 5.17
test/registered/cp/test_qwen3_30b.py CP 测试 removed 6.95
test/registered/kernels/test_mla_cp_fa3_parity.py MLA CP 测试 modified 4.3

关键符号

can_cp_split prepare_context_parallel_metadata cp_attn_forward_extend cp_all_gather_reorganized_into_tensor add_one_req prepare_mlp_sync_batch

关键源码片段

python/sglang/srt/layers/utils/cp_utils.py core-logic

核心变更文件,ContextParallelMetadata 数据类重构支持多 batch,can_cp_split 逻辑调整,cp_attn_forward_extend 适配多序列。

# ContextParallelMetadata 支持 batch size > 1
@dataclass
class ContextParallelMetadata:
    # Layout lists have length bs * cp_segment_num (= bs * 2 * cp_size).
    split_list: List[int] = None
    zigzag_index: List[int] = None
    cp_reverse_index: List[int] = None
    reverse_split_len: List[int] = None
​
    # Per-rank-aggregate lists have length cp_size.
    per_rank_actual_token: List[int] = None
    max_rank_len: List[int] = None
​
    # Per-sequence FlashAttention tensors (shape [bs] or [bs+1]).
    kv_len_prev_tensor: torch.Tensor = None # [bs] int32 CUDA
    kv_len_next_tensor: torch.Tensor = None # [bs] int32 CUDA
    actual_seq_q_prev_tensor: torch.Tensor = None # [bs] int32 CUDA
    actual_seq_q_next_tensor: torch.Tensor = None # [bs] int32 CUDA
    cu_seqlens_q_prev_tensor: torch.Tensor = None # [bs+1] int32 CUDA
    cu_seqlens_q_next_tensor: torch.Tensor = None # [bs+1] int32 CUDA
​
    # Per-seq CPU lists (useful for NSA indexer and diagnostics).
    kv_len_prev_list: List[int] = None
    kv_len_next_list: List[int] = None
    actual_seq_q_prev_list: List[int] = None
    actual_seq_q_next_list: List[int] = None
​
    # Aggregate sum of extend_seq_lens across the batch.
    total_seq_lens: int = 0
    bs: int = 1
​
​
def can_cp_split(seq_len: int, cp_size: int, forward_batch):
    # 基础条件 : CP 开启、size>1、纯 extend 模式
    from sglang.srt.model_executor.forward_batch_info import ForwardMode
    cur_cp_seq_len = seq_len // (cp_size * 2)
    if not (
        cur_cp_seq_len != 0
        and cp_size > 1
        and forward_batch.forward_mode.is_context_parallel_extend()
        and forward_batch.forward_mode != ForwardMode.MIXED
        and is_prefill_context_parallel_enabled()
    ):
        return False
​
    # 逐请求检查 extend length 是否足够 zigzag 切分
    extend_lens = getattr(forward_batch, "extend_seq_lens_cpu", None)
    if extend_lens is None:
        return True
​
    cp_min = cp_size * 2
    for L in extend_lens:
        if L < cp_min:
            # 不满足切分条件的请求 gracefully 回退到非 CP 模式
            return False
    return True
python/sglang/srt/layers/attention/dsa/dsa_indexer.py core-logic

DSA 注意力路径适配,将标量字段改为从列表中取元素,仍假设 batch=1。

# head 版本 : 从 list 中取第一个元素,仍假设 batch=1 的 DSA 路径
if (
    forward_batch.attn_cp_metadata is not None
    and is_dsa_prefill_cp_in_seq_split()
):
    kv_len_prev = forward_batch.attn_cp_metadata.kv_len_prev_list[0]
    kv_len_next = forward_batch.attn_cp_metadata.kv_len_next_list[0]
    actual_seq_q_prev = forward_batch.attn_cp_metadata.actual_seq_q_prev_list[0]
    actual_seq_q_next = forward_batch.attn_cp_metadata.actual_seq_q_next_list[0]
    # TODO: 支持 multi-batch 后需改为对应 batch 索引

评论区精华

can_cp_split 抛异常 vs graceful fallback 设计

kpham-sgl 认为抛异常会崩溃生产环境,建议 return False;Shunkangz 最初想显式暴露问题,后同意改为 graceful fallback。

结论:采用逐请求检查 extend_seq_len,不满足时 return False,由调度器回退到非 CP 模式。 · 已解决

padding 对齐粒度从 cp_size 改为 cp_size*2 设计

kpham-sgl 指出原 padding 只对齐到 cp_size,无法保证 zigzag 负载均衡,建议改为 2*cp_size。

结论:在 forward_batch_info.py 中实施对齐到 attn_cp_size * 2。 · 已解决

cu_seqlens 计算移至 prepare_context_parallel_metadata 性能

Shunkangz 解释移动目的是减少 kernel launch 前的 CPU 索引拷贝开销,让 kernel 尽早发射。

结论:接受此改动,将 cu_seqlens 计算提前到元数据准备阶段。 · 已解决

CPU 元数据计算开销担忧 性能

Fridge003 建议未来用 Triton kernel 替代 CPU 循环,kpham-sgl 表示 +1,Shunkangz 认为当前不是瓶颈但同意后续优化。

结论:留下 TODO,后续可考虑 Triton 实现。 · acknowledged

DSA indexer 仍假设 batch=1 设计

kpham-sgl 指出 dsa_indexer 中 '*_list[0]' 只适用于单请求,需要待后续支持。

结论:保留 TODO,DSA 多 batch 支持推迟到后续 PR。 · acknowledged

风险与影响

  • 核心路径回归ContextParallelMetadata 字段大幅变更,所有 CP 相关的注意力、前向、DP padding 逻辑均受影响,需关注 DeepSeek V3/V4、Qwen3 等模型的精度和 crash。
  • CPU 开销增加prepare_context_parallel_metadata 中对每序列的逐循环计算可能成为 prefill 瓶颈,尤其当 bs 较大时。
  • DSA 路径兼容性dsa_indexer.py 仍使用 kv_len_prev_list[0],在真正多 batch 前会限制 DSA CP 的 batch 能力。
  • 测试覆盖缩减:删除了 test_qwen3_30b.py(原 CP 精度测试),若 test_mla_cp_fa3_parity.py 覆盖不足可能漏掉回归。
  • 用户影响:启用 CP 时可获得更高 prefill 吞吐,但需要每个请求的 extend length 足够长(≥2×cp_size),否则回退到非 CP 模式。
  • 系统影响:新增 attn_cp_size * 2 对齐计算,可能增加少量显存/计算开销;调度器不再强制 batch=1,可能改变并发行为。
  • 团队影响:后续开发(如 DSA 多 batch、Triton 元数据 kernel)需要在此数据结构基础上继续迭代。
核心路径变更 CPU 开销增加 DSA 路径兼容性 测试覆盖缩减

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论