Prhub

#25256 [MUSA][Diffusion] Improve wan model inference speed using torch.compile

原始 PR 作者 wenqf11 合并时间 2026-05-17 22:10 文件变更 4 提交数 5 评论 4 代码增减 +15 / -1

执行摘要

torch.compile 加速 Wan 扩散模型推理

提升 Wan 模型在 MUSA 设备(以及 CUDA)上的推理速度,通过 torch.compile 优化关键操作。PR 正文引用:'Improve Wan model inference using torch.compile on MUSA device, also on CUDA.' 并提供了在 S5000 和 H200 上的单步加速数据(分别为 1.09x 和 1.05x)。

该 PR 展示了在扩散模型推理中应用 torch.compile 的典型模式:添加平台回退、新增 forward_xxx 方法、批量装饰原生方法。值得学习其平台分派和条件编译策略。改动虽小但提供清晰 benchmark,适合作为性能优化的参考案例。

讨论亮点

yeahdongcn 在 wanvideo.py 的评论中建议将 .contiguous() 放在 is_musa 条件下,以避免对其他平台产生不必要的内存拷贝。wenqf11 回复展示截图说明即使不加条件,后续的 LayerNormScaleShift 层也会自动做 contiguous,因此添加 contiguity 没有实际影响(no effect)。最终未添加条件控制,PR 经 yeahdongcn 和 mickqian 批准合并。

实现拆解

  1. MUSA 平台回退:在 python/sglang/jit_kernel/diffusion/triton/scale_shift.py 中添加 current_platform.is_musa() 分支,使 fuse_scale_shift_kernel 在 MUSA 上使用 torch_fallback
  2. MulAdd 编译优化:在 python/sglang/multimodal_gen/runtime/layers/elementwise.py 中为 MulAdd 新增 forward_musa 方法,用 @torch.compile 装饰并直接调用 forward_native,使得 MUSA 设备上的 MulAdd 操作被编译优化。
  3. Layernorm 编译优化:在 python/sglang/multimodal_gen/runtime/layers/layernorm.py 中对三个 forward_native 方法(_ScaleResidualNormScaleShift.forward_native_NormScaleShift.forward_native_NormTanhMulAdd.forward_native)添加 @torch.compile(disable=current_platform.is_npu()) 装饰,避免影响 NPU 路径。
  4. 张量连续性保证:在 python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py 中 patch_embedding 后添加 .contiguous() 确保内存连续,满足 torch.compile 要求(虽然后续层也会自动处理,但无额外开销)。
  5. 性能验证:提供了 MUSA S5000 和 Nvidia H200 上的 profiler 对比截图,未新增单元测试,但 CI 包含现有集成测试。
文件 模块 状态 重要度
python/sglang/jit_kernel/diffusion/triton/scale_shift.py JIT 内核 modified 5.87
python/sglang/multimodal_gen/runtime/layers/elementwise.py 算子层 modified 6.1
python/sglang/multimodal_gen/runtime/layers/layernorm.py 算子层 modified 5.22
python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py 模型 modified 4.7

关键符号

MulAdd.forward_musa _ScaleResidualNormScaleShift.forward_native _NormScaleShift.forward_native _NormTanhMulAdd.forward_native

关键源码片段

python/sglang/jit_kernel/diffusion/triton/scale_shift.py dependency-wiring

添加 MUSA 平台的回退路径,是 MUSA 支持的基础。所有 diffusion 模型中共用的 scale_shift 内核选择逻辑扩展至此。

# 在已有平台回退后添加 MUSA 分支
if current_platform.is_npu():
    from .npu_fallback import fuse_scale_shift_native
    fuse_scale_shift_kernel = fuse_scale_shift_nativeif current_platform.is_mps():
    from .mps_fallback import fuse_scale_shift_kernel_native
    fuse_scale_shift_kernel = fuse_scale_shift_kernel_native# 新增:MUSA 平台使用 torch 原生实现(回退)
if current_platform.is_musa():
    from .torch_fallback import fuse_scale_shift_kernel_native
    fuse_scale_shift_kernel = fuse_scale_shift_kernel_nativeif current_platform.is_cpu():
    from .torch_fallback import (
        fuse_scale_shift_kernel_native,
    )
    fuse_scale_shift_kernel = fuse_scale_shift_kernel_native
python/sglang/multimodal_gen/runtime/layers/elementwise.py core-logic

新增 forward_musa 方法,是 torch.compile 在 MulAdd 算子的直接应用入口。

class MulAdd(CustomOp):
    # ... 其他方法 ...
​
    def forward_xpu(self, a, b, c, k=0):
        return self.forward_native(a, b, c, k=k)
​
    # 新增:MUSA 专用编译版本
    @torch.compile
    def forward_musa(
        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, k: int = 0
    ):
        # 直接复用原生实现,torch.compile 负责图优化
        return self.forward_native(a, b, c, k=k)
​
    def forward_npu(self, a, b, c, k=0):
        from sgl_kernel_npu.norm.scale_shift import fused_scale_shift
        return fused_scale_shift(a, b, c, scale_constant=k)
python/sglang/multimodal_gen/runtime/layers/layernorm.py core-logic

对三个核心 Layernorm 变体的 forward_native 启用 torch.compile,影响所有使用这些层的扩散模型。同时通过条件 disable 保护 NPU 路径。

class _ScaleResidualNormScaleShift(CustomOp):
    # ... 其他方法 ...
​
    def forward_xpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)
​
    # 启用编译,但 NPU 平台除外(disable=True 会跳过编译)
    @torch.compile(disable=current_platform.is_npu())
    def forward_native(
        self,
        residual: torch.Tensor,
        x: torch.Tensor,
        gate: torch.Tensor | int,
        shift: torch.Tensor,
        scale: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # 原有实现:残差连接 + layernorm + scale-shift
        # ... ( 省略主体,保持不变 )
        return modulated, residual_outputclass _NormScaleShift(CustomOp):
    @torch.compile(disable=current_platform.is_npu())
    def forward_native(self, x, shift, scale):
        # ...
        return modulated.to(x.dtype)class _NormTanhMulAdd(CustomOp):
    @torch.compile(disable=current_platform.is_npu())
    def forward_native(self, x, scale, shift):
        # ...
        return result

评论区精华

contiguous 是否应仅用于 MUSA 设计

yeahdongcn 建议将 wanvideo.py 中的 .contiguous() 限制在 is_musa 条件下,基于 offline discussion 认为仅 MUSA 需要。wenqf11 回复展示截图说明后续 LayerNormScaleShift 会自动做 contiguous,因此加不加条件没有实际影响。

结论:未修改,保持无条件 contiguous。PR 经批准合并,说明 reviewer 接受了不影响性能的解释。 · 已解决

风险与影响

  • 性能风险:torch.compile 首次调用有编译开销,适合多步推理;单步加速较小(5-9%),但无负面副作用。
  • 平台兼容:MUSA 使用 torch_fallback,不触发 Triton 内核,正确性有保证但性能提升有限。layernorm 上的 @torch.compile 可通过 disable=is_npu() 排除 NPU,不影响现有 NPU 路径。
  • 内存开销:wanvideo.py 中额外 .contiguous() 增加一次复制,但后续层也会做,不会叠加。
  • 测试覆盖:未针对 torch.compile 编写特定测试,依赖现有集成测试,可能遗漏编译错误或目标平台差异。
  • 用户影响:Wan 模型用户可获得小幅推理加速,无需任何配置变更。
  • 系统影响:对非 MUSA/CUDA 平台无影响;NPU 平台通过条件禁用避免问题。
  • 团队影响:简化了 MUSA 平台支持,为未来其他算子使用 torch.compile 提供了参考模式。
  • 影响范围:所有使用 MulAdd、LayernormScaleShift、RMSNormScaleShift、NormTanhMulAdd 的扩散模型(不限于 Wan)都会在 CUDA/MUSA 上启用 torch.compile,layernorm 的编译优化对 CUDA 也适用(设计上兼容)。
平台特定回退 编译开销 无专用测试 NPU 条件禁用

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论