Prhub

#27081 [diffusion] Use Conv2d width padding in WanVAE

原始 PR 作者 mickqian 合并时间 2026-06-03 10:14 文件变更 1 提交数 1 评论 2 代码增减 +6 / -5

执行摘要

WanVAE 使用 Conv2d 原生宽度 padding

减少显存分配和 kernel 启动开销,通过利用 Conv2d 原生 padding 机制优化性能。PR body 提到 decode 阶段耗时从 5.2121s 降至 5.1284s(~1.5%),峰值显存从 44174MB 降至 43522MB(~1.5%)。

可精读,作为如何利用框架原生特性替代手动 pad 的案例。性能提升有限,但代码简洁性提升明显。

讨论亮点

无 review 讨论,PR 由作者自行合并。

实现拆解

  1. 修改 WanDistConv2d 初始化中的 padding 分配逻辑:在 python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py__init__ 中,当 height_halo_size > 0 时,_padding(padding[1], padding[1], 0, 0) 改为 (0, 0, 0, 0),即不再为宽度方向保留 padding;同时将 self.padding(0, 0) 改为 (0, self.padding[1]),将宽度方向的 padding 交由 nn.Conv2d 基类处理。当 height_halo_size == 0 时,_padding 的宽度部分也置零,高度部分保持不变。
  2. 优化 forward 中的 F.pad 调用:在 forward 中,用 if any(self._padding): x = F.pad(x, self._padding) 替代了无条件的 x = F.pad(x, self._padding)。由于大部分情况下 _padding 为全零,跳过了无意义的 pad 操作。
  3. 验证正向一致性:输出 byte-identical(MP4 sha256 一致),确保功能无退化。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py 扩散模型 modified 6.15

关键符号

WanDistConv2d.__init__ WanDistConv2d.forward

关键源码片段

python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py data-contract

唯一修改的文件,核心变更位于 WanDistConv2d 类的 __init__ 和 forward 方法,将宽度 padding 从显式 F.pad 迁移到 Conv2d 原生 padding。

# python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py
# 本 PR 将宽度方向 padding 从显式 F.pad 迁移到 Conv2d 原生 padding,
# 减少一次显存分配和 kernel 启动,提升 decode 阶段性能约 1.5%。class WanDistConv2d(nn.Conv2d):
    def __init__(...):
        # ... 原有初始化逻辑 ...
​
        # 旧代码:当 height_halo_size > 0 时,
        # self._padding = (self.padding[1], self.padding[1], 0, 0)
        # 新代码:置零宽度 padding,交由 Conv2d 基类处理
        if self.height_halo_size > 0:
            self._padding = (0, 0, 0, 0)
        else:
            self._padding = (
                0,
                0,
                self.padding[0],
                self.padding[0],
            )
​
        # 旧代码:self.padding = (0, 0) # 完全禁用 Conv2d 原生 padding
        # 新代码:保留宽度方向的 padding,让 Conv2d 负责宽度填充
        self.padding = (0, self.padding[1])
​
    def forward(self, x):
        # 旧代码:无条件调用 F.pad
        # x = F.pad(x, self._padding)
        # 新代码:仅当 _padding 非全零时调用(大多数情况可跳过)
        if any(self._padding):
            x = F.pad(x, self._padding)
​
        x_padded, ... = halo_exchange(x, ...)
        # ... 后续 forward 逻辑不变 ...
        out = super().forward(x_padded)
        return out

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

回归风险:虽然显式保证了 byte-identical,但若未来修改 Conv2d 的 padding 行为(例如换后端),宽度 padding 可能被重复应用(Conv2d 的 padding 和 _padding 中的宽度 padding)导致错误。当前 PR 通过将 _padding 宽度部分置零消除了此风险。
性能:warmup 后无需额外 kernel,正向性能提升稳定。
兼容性:仅影响 WanDistConv2d 内部逻辑,对外接口和高度 padding 语义未变。

影响范围:仅修改 WanDistConv2d 类,影响所有使用 WanVAE 的模型(包括 Cosmos3 等)的 decode 阶段。性能提升约 1-2%,显存降低约 1-2%。影响程度:低,属于微优化。

微优化 高度特定于 WanVAE

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论