Prhub

#27151 [diffusion] Skip unused WanVAE halo send copies

原始 PR 作者 mickqian 合并时间 2026-06-04 10:23 文件变更 1 提交数 4 评论 4 代码增减 +25 / -5

执行摘要

跳过边界 rank 的 WanVAE halo 发送副本

在 WanVAE 分布式解码中,halo 交换操作在边界 rank(rank=0 或 rank=world_size-1)上仍会为发送创建 contiguouis 副本,但这些副本实际上不会被使用(因为边界 rank 没有对应的发送目标)。同时,使用 torch.empty_like 分配接收缓冲区可能不匹配输入张量的 channels_last 或 channels_last_3d 格式,导致后续 concat 触发昂贵的布局转换。

值得精读的实现级优化,展示了如何通过内存格式感知来避免分布式推理中的显式/隐式数据副本。_halo_memory_format 的检测模式可推广到其他分布式卷积/注意力模块。

讨论亮点

AI 审核机器人指出两个关键问题:

  • 使用 torch.empty 默认分配 contiguous 缓冲区,若输入张量是 channels_last_3d 格式,后续 concat 会触发昂贵的布局转换。
  • 发送张量 .contiguous() 默认转为 contiguous 格式,而接收缓冲区可能是 channels_last_3d 格式,导致 P2P 通信时数据损坏。
    PR 作者在后续提交中新增 _halo_memory_format 函数动态检测内存格式,并在 _ensure_recv_buf 和发送前 contiguous 调用中传入该格式,解决了上述问题。

实现拆解

  1. 新增 _halo_memory_format 函数wan_dist_utils.py:106-112): 根据参考张量的维度和步幅动态推断最佳内存格式(channels_last_3d / channels_last / contiguous_format),用于分配接收缓冲区和控制发送张量的布局。
  2. 改造 _ensure_recv_bufwan_dist_utils.py:124-141): 将 torch.empty_like(reference) 替换为 torch.empty(..., memory_format=memory_format),并增加 is_contiguous(memory_format) 检查,若现有缓存不匹配则重新分配,避免后续 concat 的性能损失。
  3. 优化 halo_exchange 函数wan_dist_utils.py:144-197): 将 top_row/bottom_row = x[...].contiguous() 分解为 top_row_ref/bottom_row_ref = x[...](延迟切片),仅在非边界 rank 上发送时才调用 .contiguous(memory_format=...) 创建副本,跳过边界 rank 的无用发送副本。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py 分布式基础 modified 7.16

关键符号

_halo_memory_format _ensure_recv_buf halo_exchange

关键源码片段

python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py data-contract

唯一修改文件,核心变更:新增 `_halo_memory_format`、重构 `_ensure_recv_buf` 和 `halo_exchange` 以支持内存格式感知的缓冲区分配与延迟副本创建。

# wan_dist_utils.py 关键变更:内存格式感知的 halo 交换def _halo_memory_format(reference: torch.Tensor) -> torch.memory_format:
    # 根据参考张量的内存布局,推断接收 / 发送缓冲区应使用的格式
    if reference.dim() > 1 and reference.stride(1) == 1: # 末尾维度连续 => 可能是 channels_last
        if reference.dim() == 5 and hasattr(torch, "channels_last_3d"):
            return torch.channels_last_3d
        if reference.dim() == 4:
            return torch.channels_last
    return torch.contiguous_format # 默认 contiguous
​
​
def _ensure_recv_buf(
    recv_buf: torch.Tensor | None, reference: torch.Tensor
) -> torch.Tensor:
    # 动态检测参考张量的内存格式,用于分配接收缓冲区
    memory_format = _halo_memory_format(reference)
    if (
        recv_buf is None
        or recv_buf.shape != reference.shape
        or recv_buf.dtype != reference.dtype
        or recv_buf.device != reference.device
        or not recv_buf.is_contiguous(memory_format=memory_format)
    ):
        # 用推断的格式创建空张量,避免后续 concat 隐式转换
        return torch.empty(
            reference.shape,
            dtype=reference.dtype,
            device=reference.device,
            memory_format=memory_format,
        )
    return recv_buf
​
​
def halo_exchange(
    x: torch.Tensor,
    height_halo_size: int = 1,
    recv_top_buf: torch.Tensor | None = None,
    recv_bottom_buf: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if height_halo_size == 0:
        return x, recv_top_buf, recv_bottom_buf
    # ... 省略 group/rank 初始化 ...
    # 延迟切片,避免立即创建 contiguous 副本
    top_row_ref = x[..., :height_halo_size, :]
    bottom_row_ref = x[..., -height_halo_size:, :]
​
    recv_top_buf = _ensure_recv_buf(recv_top_buf, top_row_ref)
    recv_bottom_buf = _ensure_recv_buf(recv_bottom_buf, bottom_row_ref)
​
    p2p_ops = []
    if rank > 0:
        # 非首边界:发送时创建与接收缓冲区同格式的 contiguous 副本
        prev_rank = group_ranks[rank - 1]
        top_row = top_row_ref.contiguous(memory_format=_halo_memory_format(top_row_ref))
        p2p_ops.append(dist.P2POp(dist.irecv, recv_top_buf, prev_rank, group))
        p2p_ops.append(dist.P2POp(dist.isend, top_row, prev_rank, group))
    if rank < world_size - 1:
        next_rank = group_ranks[rank + 1]
        bottom_row = bottom_row_ref.contiguous(
            memory_format=_halo_memory_format(bottom_row_ref)
        )
        p2p_ops.append(dist.P2POp(dist.isend, bottom_row, next_rank, group))
        p2p_ops.append(dist.P2POp(dist.irecv, recv_bottom_buf, next_rank, group))
    # 边界 rank 的接收缓冲区直接置零(无发送操作)
    if rank == 0:
        recv_top_buf.zero_()
    if rank == world_size - 1:
        recv_bottom_buf.zero_()
    # ... 执行 batch P2P 并 concat ...

评论区精华

内存格式不一致导致性能 / 正确性问题 性能

AI 审核指出:使用 `torch.empty` 默认分配 contiguous 缓冲区,若输入是 channels_last_3d 格式,后续 concat 会触发昂贵的布局转换;发送张量使用无参数的 `.contiguous()` 默认为 contiguous,而接收缓冲区可能为 channels_last_3d 格式,P2P 通信时数据损坏。

结论:PR 作者通过新增 `_halo_memory_format` 函数,在分配接收缓冲区和发送 contiguous 时都传入推断的格式,保证布局一致。 · 已解决

风险与影响

核心风险集中在 _halo_memory_format 函数对内存格式的推断逻辑:对于高维张量(dim>5)或非标准步幅的输入,返回 contiguous_format 可能并非最优,但功能正确。P2P 通信和 concat 路径的布局现已同步,回归风险低。未同步修改其他可能调用 _ensure_recv_bufhalo_exchange 的模块(目前仅 WanVAE 使用)。

影响范围限于 WanVAE 分布式解码路径(wan_dist_utils.py)。性能影响:边界 rank 节省两次 contiguous 副本创建;所有 rank 的接收缓冲区布局与输入一致,避免 concat 隐式转换。功能上保持输出 bit-exact 一致(PR 验证了 MP4 sha 匹配)。对用户透明,无配置变更。

核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论