Prhub

#24332 [Codex] Diffusion handle non-contiguous CFG communication

原始 PR 作者 BBuf 合并时间 2026-05-06 17:27 文件变更 2 提交数 3 评论 3 代码增减 +4 / -0

执行摘要

修复扩散模型 CFG 并行中非连续张量通信崩溃

修复 jdopensource/JoyAI-Image-Edit-Diffusers 模型在 --num-gpus=2 自动启用 CFG 并行时,去噪阶段因非连续张量调用 torch.distributed.all_reducebroadcast 而崩溃的问题。PR body 中描述了错误信息 ValueError: Tensors must be contiguous 和具体触发点。

建议合并此 PR,因为它修复了 CPG 并行在 JoyAI 等模型上的功能性崩溃,并带来了显著的性能提升。但在合并前,应评估 review 中提出的 in-place 语义问题——如果调用者依赖原始张量更新,需复制回结果(如 input_.copy_(contiguous_result));若当前所有调用者都不依赖,则可忽略。建议补充一个单元测试用例,覆盖非连续张量输入场景。

讨论亮点

review 评论来自 gemini-code-assist[bot],指出在 cfg_model_parallel_all_reduce 中对非连续输入调用 contiguous() 会创建副本,破坏了对原始张量原地修改的语义(all_reduce 通常期望 in-place),可能导致调用者继续使用未规约的原始张量。建议将结果复制回原始张量以保持行为一致。该评论尚未被作者回复或解决。

实现拆解

  1. cfg_model_parallel_all_reduce 中添加连续化检查:在 python/sglang/multimodal_gen/runtime/distributed/communication_op.pycfg_model_parallel_all_reduce 函数中,在调用底层 all_reduce 前检查 input_.is_contiguous(),若非连续则通过 input_.contiguous() 创建连续副本,确保底层分布式调用成功。
  2. broadcast 中添加连续化检查:在 python/sglang/multimodal_gen/runtime/distributed/group_coordinator.pybroadcast 方法中,同样在调用 torch.distributed.broadcast 前对非连续张量进行 contiguous() 转换。
  3. 保持快速路径不变:仅当张量非连续时才执行额外复制,不影响已有连续张量的性能。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/distributed/communication_op.py 分布式通信 modified 5.75
python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py 分布式通信 modified 5.75

关键符号

cfg_model_parallel_all_reduce GroupCoordinator.broadcast

关键源码片段

python/sglang/multimodal_gen/runtime/distributed/communication_op.py core-logic

修改 `cfg_model_parallel_all_reduce`,在调用底层 all_reduce 前对非连续张量执行 contiguous(),修复 CFG 并行中的崩溃。review 评论对此处的 in-place 语义提出风险。

# python/sglang/multimodal_gen/runtime/distributed/communication_op.pydef cfg_model_parallel_all_reduce(
    input_: torch.Tensor,
    op: torch._C._distributed_c10d.ReduceOp = torch._C._distributed_c10d.ReduceOp.SUM,
) -> torch.Tensor:
    """All-reduce the input tensor across CFG parallel group."""
    # 修复:底层 all_reduce 要求连续张量,非连续时显式转为连续
    # 注意:contiguous() 可能创建副本,破坏 in-place 语义
    if not input_.is_contiguous():
        input_ = input_.contiguous()
    return get_cfg_group().all_reduce(input_, op=op)
python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py core-logic

修改 `broadcast` 方法,添加与 all_reduce 相同的 contiguous() 保护,修复 JoyAI 模型在 broadcast 上的崩溃。

# python/sglang/multimodal_gen/runtime/distributed/group_coordinator.pyclass GroupCoordinator:
    # ...
    def broadcast(self, input_: torch.Tensor, src: int = 0, async_op: bool = False):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"
​
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        # 修复:底层 broadcast 要求连续张量,非连续时显式转为连续
        if not input_.is_contiguous():
            input_ = input_.contiguous()
        torch.distributed.broadcast(
            input_,
            src=self.ranks[src],
            group=self.device_group,
            async_op=async_op,
        )
        return input_

评论区精华

all_reduce 非连续张量 contiguity 破坏 in-place 语义 正确性

gemini-code-assist[bot] 指出在 cfg_model_parallel_all_reduce 中,对非连续输入使用 contiguous() 创建副本后,规约在副本上进行,返回副本,而原始张量未更新。若调用者依赖原始张量原地更新,则会得到未规约的数据。建议将结果复制回原始张量。

结论:未解决。PR 作者未回复或修改代码。 · unresolved

风险与影响

  1. 语义破坏风险all_reduce 文档和习惯用法支持原地修改,但当前实现仅在非连续时返回新张量,调用者若依赖原始张量更新将得到未规约结果。这可能影响依赖 in-place 行为的代码路径(如后续张量共享引用)。
  2. 性能影响:contiguous() 调用会触发内存复制和重新排列,对非连续大张量有额外开销。但仅当非连续时触发,且 CFG 并行本身较少见非连续场景,风险可控。
  3. 测试覆盖不足:源码变更未附带新测试用例,仅依赖现有单元测试和手工验证。

用户影响:修复了 JoyAI 等扩散模型在 CFG 并行模式下的崩溃,使这些模型能正常使用自动 CFG 并行加速。根据 PR 数据,对 JoyAI 可带来 24-27% 的端到端延迟改善。
系统影响:改动集中在两个通信函数,影响范围仅限于扩散模型的 CFG 并行路径,不涉及序列并行或其他通信原语。
团队影响:小型改动(4 行增加),易于审查和合并。

语义破坏风险 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论