执行摘要
WanVAE 使用 Conv2d 原生宽度 padding
减少显存分配和 kernel 启动开销,通过利用 Conv2d 原生 padding 机制优化性能。PR body 提到 decode 阶段耗时从 5.2121s 降至 5.1284s(~1.5%),峰值显存从 44174MB 降至 43522MB(~1.5%)。
可精读,作为如何利用框架原生特性替代手动 pad 的案例。性能提升有限,但代码简洁性提升明显。
无 review 讨论,PR 由作者自行合并。
减少显存分配和 kernel 启动开销,通过利用 Conv2d 原生 padding 机制优化性能。PR body 提到 decode 阶段耗时从 5.2121s 降至 5.1284s(~1.5%),峰值显存从 44174MB 降至 43522MB(~1.5%)。
可精读,作为如何利用框架原生特性替代手动 pad 的案例。性能提升有限,但代码简洁性提升明显。
无 review 讨论,PR 由作者自行合并。
python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py 的 __init__ 中,当 height_halo_size > 0 时,_padding 从 (padding[1], padding[1], 0, 0) 改为 (0, 0, 0, 0),即不再为宽度方向保留 padding;同时将 self.padding 从 (0, 0) 改为 (0, self.padding[1]),将宽度方向的 padding 交由 nn.Conv2d 基类处理。当 height_halo_size == 0 时,_padding 的宽度部分也置零,高度部分保持不变。 forward 中,用 if any(self._padding): x = F.pad(x, self._padding) 替代了无条件的 x = F.pad(x, self._padding)。由于大部分情况下 _padding 为全零,跳过了无意义的 pad 操作。 | 文件 | 模块 | 状态 | 重要度 |
|---|---|---|---|
python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py |
扩散模型 | modified | 6.15 |
python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py
data-contract
唯一修改的文件,核心变更位于 WanDistConv2d 类的 __init__ 和 forward 方法,将宽度 padding 从显式 F.pad 迁移到 Conv2d 原生 padding。
# python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py
# 本 PR 将宽度方向 padding 从显式 F.pad 迁移到 Conv2d 原生 padding,
# 减少一次显存分配和 kernel 启动,提升 decode 阶段性能约 1.5%。
class WanDistConv2d(nn.Conv2d):
def __init__(...):
# ... 原有初始化逻辑 ...
# 旧代码:当 height_halo_size > 0 时,
# self._padding = (self.padding[1], self.padding[1], 0, 0)
# 新代码:置零宽度 padding,交由 Conv2d 基类处理
if self.height_halo_size > 0:
self._padding = (0, 0, 0, 0)
else:
self._padding = (
0,
0,
self.padding[0],
self.padding[0],
)
# 旧代码:self.padding = (0, 0) # 完全禁用 Conv2d 原生 padding
# 新代码:保留宽度方向的 padding,让 Conv2d 负责宽度填充
self.padding = (0, self.padding[1])
def forward(self, x):
# 旧代码:无条件调用 F.pad
# x = F.pad(x, self._padding)
# 新代码:仅当 _padding 非全零时调用(大多数情况可跳过)
if any(self._padding):
x = F.pad(x, self._padding)
x_padded, ... = halo_exchange(x, ...)
# ... 后续 forward 逻辑不变 ...
out = super().forward(x_padded)
return out
当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。
回归风险:虽然显式保证了 byte-identical,但若未来修改 Conv2d 的 padding 行为(例如换后端),宽度 padding 可能被重复应用(Conv2d 的 padding 和 _padding 中的宽度 padding)导致错误。当前 PR 通过将 _padding 宽度部分置零消除了此风险。
性能:warmup 后无需额外 kernel,正向性能提升稳定。
兼容性:仅影响 WanDistConv2d 内部逻辑,对外接口和高度 padding 语义未变。
影响范围:仅修改 WanDistConv2d 类,影响所有使用 WanVAE 的模型(包括 Cosmos3 等)的 decode 阶段。性能提升约 1-2%,显存降低约 1-2%。影响程度:低,属于微优化。
当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。
参与讨论