执行摘要
- 一句话:上下文并行支持 batch size > 1
- 推荐动作:值得精读,尤其
ContextParallelMetadata 从单序列到多序列的设计演进,以及 padding 策略的权衡。讨论中的 CPU 开销担忧和未来 Triton 替代方向值得关注。架构师应关注 DSA 路径的遗留 TODO。
功能与动机
PR body 中明确目标:"Enable batch size > 1 with context parallel." 此前 CP 的 zigzag 切分和元数据仅支持单请求,严重限制了吞吐。Review 中 kpham-sgl 指出当前 MHA CP 实现中存在两处已知 bug,并强调需要 attn_cp_size == 4 且 bs > 1 的测试覆盖。
实现拆解
- ContextParallelMetadata 数据结构重构(
cp_utils.py):将 kv_len_prev、kv_len_next 等标量替换为形状 [bs] 的 CUDA tensor(如 kv_len_prev_tensor),新增 cu_seqlens_q_prev_tensor 辅助 FlashAttention,新增 bs 字段记录批大小,total_seq_lens 改为 int。
- CP 切分判断升级(
cp_utils.py 中 can_cp_split):移除 batch_size == 1 限制,改为逐个检查 extend_seq_lens_cpu 是否满足 len >= cp_size * 2,不满足时返回 False 而非抛异常。
- Padding 对齐粒度调整(
forward_batch_info.py):对齐基数从 attn_cp_size 改为 attn_cp_size * 2,确保 zigzag 切分负载均衡。
- 调度策略松绑(
schedule_policy.py):移除 self.prefill_context_parallel_enabled 对 add_one_req 中 can_run_list 长度的限制,允许多请求进入。
- 注意力后端适配:在
cp_attn_forward_extend 中按 total_q_prev_tokens 切分 q,传入 cu_seqlens_q_prev/next_tensor;在 dsa_indexer.py 中将标量字段替换为 kv_len_prev_list[0] 等列表元素(仍假设 batch=1,预留 TODO)。
- 调用参数统一:所有模型入口(
deepseek_v2.py、deepseek_nextn.py、deepseek_v4.py、qwen3_moe.py 等)将 prepare_context_parallel_metadata 的 extend_lens= 更名为 extend_seqs_len=。
- 测试配套:删除
test_qwen3_30b.py(原 CP 精度测试),修改 test_mla_cp_fa3_parity.py 适配新字段。
关键文件:
python/sglang/srt/layers/utils/cp_utils.py(模块 CP核心;类别 source;类型 core-logic;符号 cp_all_gather_reorganized_into_tensor): 核心变更文件,ContextParallelMetadata 数据类重构支持多 batch,can_cp_split 逻辑调整,cp_attn_forward_extend 适配多序列。
python/sglang/srt/layers/attention/dsa/dsa_indexer.py(模块 DSA索引器;类别 source;类型 core-logic): DSA 注意力路径适配,将标量字段改为从列表中取元素,仍假设 batch=1。
python/sglang/srt/models/deepseek_v2.py(模块 DeepSeek模型;类别 source;类型 data-contract): DeepSeek V2 模型入口,调整 prepare_context_parallel_metadata 参数名为 extend_seqs_len。
python/sglang/srt/models/deepseek_nextn.py(模块 NextN模型;类别 source;类型 data-contract): DeepSeek NextN 模型入口,同样调整参数名。
python/sglang/srt/managers/schedule_policy.py(模块 调度器;类别 source;类型 core-logic): 调度策略移除 prefill_context_parallel_enabled 限制,允许批处理多请求。
python/sglang/srt/model_executor/forward_batch_info.py(模块 批处理信息;类别 source;类型 data-contract): 调整 padding 对齐粒度从 attn_cp_size 改为 attn_cp_size * 2,保证 zigzag 负载均衡。
test/registered/cp/test_qwen3_30b.py(模块 CP测试;类别 test;类型 deletion;符号 TestQwen330B, setUpClass, tearDownClass, test_gsm8k): 删除原 CP 精度测试文件,可能以其他测试替代。
test/registered/kernels/test_mla_cp_fa3_parity.py(模块 MLA CP测试;类别 test;类型 test-coverage): 适配新的 ContextParallelMetadata 字段,构建正确的测试数据。
关键符号:can_cp_split, prepare_context_parallel_metadata, cp_attn_forward_extend, cp_all_gather_reorganized_into_tensor, add_one_req, prepare_mlp_sync_batch
关键源码片段
python/sglang/srt/layers/utils/cp_utils.py
核心变更文件,ContextParallelMetadata 数据类重构支持多 batch,can_cp_split 逻辑调整,cp_attn_forward_extend 适配多序列。
# ContextParallelMetadata 支持 batch size > 1
@dataclass
class ContextParallelMetadata:
# Layout lists have length bs * cp_segment_num (= bs * 2 * cp_size).
split_list: List[int] = None
zigzag_index: List[int] = None
cp_reverse_index: List[int] = None
reverse_split_len: List[int] = None
# Per-rank-aggregate lists have length cp_size.
per_rank_actual_token: List[int] = None
max_rank_len: List[int] = None
# Per-sequence FlashAttention tensors (shape [bs] or [bs+1]).
kv_len_prev_tensor: torch.Tensor = None # [bs] int32 CUDA
kv_len_next_tensor: torch.Tensor = None # [bs] int32 CUDA
actual_seq_q_prev_tensor: torch.Tensor = None # [bs] int32 CUDA
actual_seq_q_next_tensor: torch.Tensor = None # [bs] int32 CUDA
cu_seqlens_q_prev_tensor: torch.Tensor = None # [bs+1] int32 CUDA
cu_seqlens_q_next_tensor: torch.Tensor = None # [bs+1] int32 CUDA
# Per-seq CPU lists (useful for NSA indexer and diagnostics).
kv_len_prev_list: List[int] = None
kv_len_next_list: List[int] = None
actual_seq_q_prev_list: List[int] = None
actual_seq_q_next_list: List[int] = None
# Aggregate sum of extend_seq_lens across the batch.
total_seq_lens: int = 0
bs: int = 1
def can_cp_split(seq_len: int, cp_size: int, forward_batch):
# 基础条件 : CP 开启、size>1、纯 extend 模式
from sglang.srt.model_executor.forward_batch_info import ForwardMode
cur_cp_seq_len = seq_len // (cp_size * 2)
if not (
cur_cp_seq_len != 0
and cp_size > 1
and forward_batch.forward_mode.is_context_parallel_extend()
and forward_batch.forward_mode != ForwardMode.MIXED
and is_prefill_context_parallel_enabled()
):
return False
# 逐请求检查 extend length 是否足够 zigzag 切分
extend_lens = getattr(forward_batch, "extend_seq_lens_cpu", None)
if extend_lens is None:
return True
cp_min = cp_size * 2
for L in extend_lens:
if L < cp_min:
# 不满足切分条件的请求 gracefully 回退到非 CP 模式
return False
return True
python/sglang/srt/layers/attention/dsa/dsa_indexer.py
DSA 注意力路径适配,将标量字段改为从列表中取元素,仍假设 batch=1。
# head 版本 : 从 list 中取第一个元素,仍假设 batch=1 的 DSA 路径
if (
forward_batch.attn_cp_metadata is not None
and is_dsa_prefill_cp_in_seq_split()
):
kv_len_prev = forward_batch.attn_cp_metadata.kv_len_prev_list[0]
kv_len_next = forward_batch.attn_cp_metadata.kv_len_next_list[0]
actual_seq_q_prev = forward_batch.attn_cp_metadata.actual_seq_q_prev_list[0]
actual_seq_q_next = forward_batch.attn_cp_metadata.actual_seq_q_next_list[0]
# TODO: 支持 multi-batch 后需改为对应 batch 索引
评论区精华
风险与影响
- 风险:
- 核心路径回归:
ContextParallelMetadata 字段大幅变更,所有 CP 相关的注意力、前向、DP padding 逻辑均受影响,需关注 DeepSeek V3/V4、Qwen3 等模型的精度和 crash。
- CPU 开销增加:
prepare_context_parallel_metadata 中对每序列的逐循环计算可能成为 prefill 瓶颈,尤其当 bs 较大时。
- DSA 路径兼容性:
dsa_indexer.py 仍使用 kv_len_prev_list[0],在真正多 batch 前会限制 DSA CP 的 batch 能力。
- 测试覆盖缩减:删除了
test_qwen3_30b.py(原 CP 精度测试),若 test_mla_cp_fa3_parity.py 覆盖不足可能漏掉回归。
- 影响:
- 用户影响:启用 CP 时可获得更高 prefill 吞吐,但需要每个请求的 extend length 足够长(≥2×cp_size),否则回退到非 CP 模式。
- 系统影响:新增
attn_cp_size * 2 对齐计算,可能增加少量显存/计算开销;调度器不再强制 batch=1,可能改变并发行为。
- 团队影响:后续开发(如 DSA 多 batch、Triton 元数据 kernel)需要在此数据结构基础上继续迭代。
- 风险标记:核心路径变更, CPU开销增加, DSA路径兼容性, 测试覆盖缩减
关联脉络
- PR #23292 prepare_context_parallel_metadata 改动(未最终确定): kpham-sgl 在讨论中引用此 PR 作为参考,涉及 metadata 计算调整。
- PR #25821 nsa 替换为 dsa 的冲突引入: Fridge003 要求解决与此 PR 的冲突,将 'nsa' 替换为 'dsa'。
- PR #26380 [core] WAR barrier for overlap schedule buffer writes: 与本 PR 同属 CP 调度路径优化,后续有交互。
参与讨论