Prhub

#27143 [diffusion] Batch USP replicated KV prefix all-to-all

原始 PR 作者 mickqian 合并时间 2026-06-03 21:22 文件变更 1 提交数 2 评论 2 代码增减 +25 / -3

执行摘要

批量 CUDA A2A 通信优化扩散模型推理

现有 replicated KV 前缀分割路径为 Q、K、V 分别调用 _usp_input_all_to_all,导致三次 A2A 通信;利用 CUDA 的异步 A2A 可合并为一次批量通信,提升 GPU 利用率并减少延迟。

建议在合并后尽快进行多 GPU 扩散模型的集成测试,验证正确性和性能收益。该 PR 的设计决策(合并通信与共享 stream)值得后续类似优化参考,但风险较低,可部署。

讨论亮点

gemini-code-assist[bot] 建议将每实例的 CUDA stream 改为类级共享,避免模型层数多时创建过多 stream(如 32+),以减少资源开销并避免 CUDA Graph 捕获问题。该建议被采纳并在第二个提交中实现。

实现拆解

  1. 新增导入与类级 stream:在 layer.py 中导入 async_a2a_communicate,在 USPAttention 类中添加类变量 _usp_a2a_stream 和方法 _get_usp_a2a_stream,共享同一个 CUDA stream 以避免每实例创建过多 stream。
  2. 修改 _forward_with_replicated_kv_prefix_split:当输入在 CUDA 设备上时,用 async_a2a_communicate([q, k_shard, v_shard], ...) 批量执行 A2A,并对输出调用 .contiguous() 以确保后续注意力操作内存连续;非 CUDA 设备保持原有三次独立调用。
  3. 提交历史与反馈响应:首个提交实现了批量 A2A;第二个提交根据 Review 建议将 stream 从实例级改为类级共享,优化资源利用。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/layers/attention/layer.py 注意力层 modified 7.22

关键符号

_get_usp_a2a_stream _forward_with_replicated_kv_prefix_split

关键源码片段

python/sglang/multimodal_gen/runtime/layers/attention/layer.py core-logic

核心变更文件,修改了 USPAttention 类的 stream 管理以及 replicated KV 前缀分割路径的通信方式。

# class USPAttention (partial, focused on stream + batch A2A changes)class USPAttention(nn.Module):
    _usp_a2a_stream = None # 类级共享 stream,避免每实例创建过多 stream
​
    def __init__(self, ...):
        ...
        self.skip_sequence_parallel = skip_sequence_parallel
​
    def _get_usp_a2a_stream(self):
        # 懒初始化:所有 USPAttention 实例共享同一个 CUDA stream
        if USPAttention._usp_a2a_stream is None:
            USPAttention._usp_a2a_stream = torch.get_device_module().Stream()
        return USPAttention._usp_a2a_stream
​
    def _forward_with_replicated_kv_prefix_split(self, q, k_shard, v_shard):
        """split form avoids materializing full K/V before Ulysses all-to-all"""
        sp_rank = get_sp_parallel_rank()
​
        # CUDA 路径:批量异步 A2A + contiguous 确保内存连续
        if q.device.type == "cuda":
            q, k_shard, v_shard = async_a2a_communicate(
                [q, k_shard, v_shard],
                get_ulysses_parallel_world_size(),
                get_sp_group().ulysses_group,
                self._get_usp_a2a_stream(),
                local_seq_2_local_head=True,
            )
            q = q.contiguous()
            k_shard = k_shard.contiguous()
            v_shard = v_shard.contiguous()
        else:
            # 非 CUDA 回退:保持原有三次独立 A2A 调用
            q = _usp_input_all_to_all(q, head_dim=2)
            k_shard = _usp_input_all_to_all(k_shard, head_dim=2)
            v_shard = _usp_input_all_to_all(v_shard, head_dim=2)
​
        h_kv_local = k_shard.shape[2]
        h_start = sp_rank * h_kv_local
        ...

评论区精华

CUDA stream 共享建议 性能

gemini-code-assist[bot] 建议将每实例的 CUDA stream 改为类级共享,避免模型层数多时创建过多 stream 并防止 CUDA Graph 捕获问题。

结论:采纳建议,在第二个提交中将 stream 改为类级共享(`_usp_a2a_stream` 类变量)。 · 已解决

风险与影响

主要风险在于 CUDA 路径的行为变更:async_a2a_communicate 是否完全等价于三次顺序 A2A,特别是涉及序列维度和 head 维度的切分。local_seq_2_local_head=True 参数的正确性需要多 GPU 测试验证。另外,.contiguous() 会引入显存拷贝,在极端序列长度下可能抵消并行收益。本次缺少对应的单元测试或集成测试覆盖此路径。

影响范围限于扩散模型的 CUDA 多 GPU 推理路径,特别是使用 Ulysses 序列并行且 KV 复制的跨注意力场景。预期将减少通信延迟,提升 GPU 利用率,但非 CUDA 后端(如 CPU)不受影响。

缺少多 GPU 测试覆盖 CUDA 路径与非 CUDA 路径行为差异

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论