Prhub

#27440 [Diffusion] Avoid GPU syncs in UniPC scheduler

原始 PR 作者 BBuf 合并时间 2026-06-06 22:01 文件变更 1 提交数 1 评论 3 代码增减 +6 / -6

执行摘要

用 torch.stack 替换 torch.tensor 避免 GPU 同步

Cosmos3 H200 的 denoising profile 显示 aten::item / _local_scalar_dense / cudaStreamSynchronize 在 UniPC 调度器中重复出现(torch.tensor 构造时从 GPU 标量创建张量会触发 CPU 同步)。保持张量在设备上构建可以消除这些同步开销,同时保持数值等价。

值得精读,尤其关注如何在推理框架中通过简单代码替换消除隐式 GPU-CPU 同步。建议后续考虑采纳 reviewer 的 torch.ones_like 简化建议。

讨论亮点

Gemini Code Assist Bot 建议将 torch.ones((), dtype=h.dtype, device=h.device) 简化为 torch.ones_like(h),认为这样更简洁且自动推断 shape/dtype/device。该建议未被采纳但属于风格优化,不影响性能。

实现拆解

  1. 替换标量张量创建(所有 4 处):在 multistep_uni_p_bh_updatemultistep_uni_c_bh_update 两个方法中,将 rks.append(1.0) 改为 rks.append(torch.ones(() , dtype=h.dtype, device=h.device)),避免从 Python 浮点数构造时触发 CPU 同步。
  2. 替换列表到张量的转换(所有 2 处):将 torch.tensor(rks, device=device) 改为 torch.stack(rks).to(device=device),因为 rks 中的元素已经是 GPU 上的标量张量,用 stack 直接拼接而不需要隐式传递数据到 CPU。
  3. 替换系数列表 b 的构建(所有 2 处):将 torch.tensor(b, device=device) 改为 torch.stack(b).to(device=device),理由同上,b 中的元素也是 GPU 张量。
  4. 测试验证:通过 test/unit/test_cosmos3.py(37 passed),确保数值正确性。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py 扩散调度器 modified 6.02

关键符号

multistep_uni_p_bh_update multistep_uni_c_bh_update

关键源码片段

python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py performance

包含所有 6 处修改:用 torch.stack 替代 torch.tensor,并预创建设备上的标量张量以避免 GPU 同步。

# python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py# 修改 1: multistep_uni_p_bh_update 中 rks 的构建
# 原 : rks.append(1.0) + torch.tensor(rks, device=device) -> 触发 CPU 同步
# 改 : 直接 append 一个 GPU 标量张量,然后用 torch.stack 拼接
rks.append(torch.ones((), dtype=h.dtype, device=h.device))
rks = torch.stack(rks).to(device=device)# 修改 2: multistep_uni_p_bh_update 中系数 b 的构建
# 原 : b = torch.tensor(b, device=device)
# 改 : 由于 b 中元素已是 GPU 张量,用 stack 避免隐式数据传输
b = torch.stack(b).to(device=device)# 修改 3 & 4: multistep_uni_c_bh_update 中完全相同模式的替换
rks.append(torch.ones((), dtype=h.dtype, device=h.device))
rks = torch.stack(rks).to(device=device)
b = torch.stack(b).to(device=device)

评论区精华

使用 torch.ones_like 简化标量创建 style

gemini-code-assist[bot] 建议将 `torch.ones((), dtype=h.dtype, device=h.device)` 替换为 `torch.ones_like(h)`,认为更简洁并自动推断属性。

结论:建议未被采用,但属于可选风格优化。 · unresolved

风险与影响

低风险:变更仅涉及将 Python 标量和 torch.tensor 替换为设备上已有张量的 stack 操作,逻辑等价。若 h 是标量张量(shape 为 ()),torch.ones_like(h)torch.ones((), ...) 行为一致。但需注意 rks 列表中元素数据类型:原 1.0 是 float,若 h 为 float 则无问题;若 h 为 half 或其他 dtype,需确保 torch.ones_like(h) 与之匹配。

直接影响 UniPC 调度器(扩散模型 denoising 过程),Cosmos3 模型去噪阶段加速约 1.8%(536ms 节省)。不影响 API 接口、数据格式或模型精度,仅内部实现优化。

微小边界情况风险(dtype 匹配)

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论