Prhub

#23148 [codex] diffusion: enable group norm silu fuse by default

原始 PR 作者 BBuf 合并时间 2026-05-02 20:55 文件变更 5 提交数 9 评论 8 代码增减 +48 / -42

执行摘要

默认启用 GroupNorm+SiLU 融合以加速 HunyuanVideo VAE 解码

此前融合路径需通过环境变量手动启用,增加了用户部署成本。基准测试表明该优化稳定有效,应默认开启。PR body 明确说明:"This makes the CUDA GroupNorm+SiLU VAE fast path the default instead of requiring an environment variable."

值得精读,设计上采用安全 fallback 模式值得借鉴。review 中关于延迟导入的讨论也有参考价值。

讨论亮点

Review 中 gemini-code-assist[bot] 建议将延迟导入移出函数体以提高热路径性能,yuan-luo 表示同意,但最终代码仍保留了延迟导入(未采纳)。此外 review 整体通过了其余设计。

实现拆解

  1. 在 python/sglang/jit_kernel/diffusion/group_norm_silu.py 中新增 apply_group_norm_silu 函数,作为统一入口。它检查 x.is_cuda、梯度禁用、norm 为 GroupNorm 且 affine、activation 为 SiLU 且非 inplace 等条件,满足时调用 Triton 内核,否则回退到 activation(norm(x))。
  2. 在 python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py 中删除原有的 _apply_hunyuan_group_norm_silu 函数和 envs 导入,替换为 apply_group_norm_silu,并将调用扩展到 encoder/decoder 的 conv_norm_out 和 conv_act。
  3. 在 python/sglang/multimodal_gen/envs.py 中删除 SGLANG_USE_CUDA_HUNYUANVIDEO_GROUP_NORM_SILU 的定义和 lazy_bool 注册,该环境变量不再可用。
  4. 在 python/sglang/jit_kernel/diffusion/triton/group_norm_silu.py 的 _can_use_triton_group_norm_silu 中添加 not torch.is_grad_enabled() 检查,确保梯度启用时不会误用内核。
  5. 在 python/sglang/jit_kernel/tests/diffusion/test_group_norm_silu.py 中将 test_apply_hunyuan_group_norm_silu 重命名为 test_apply_group_norm_silu,移除 monkeypatch 环境变量的操作,直接验证通用函数。
文件 模块 状态 重要度
python/sglang/jit_kernel/diffusion/group_norm_silu.py JIT 内核 added 7.54
python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py 扩散模型 modified 7.24
python/sglang/jit_kernel/tests/diffusion/test_group_norm_silu.py JIT 内核 modified 5.32
python/sglang/multimodal_gen/envs.py 环境配置 modified 4.39
python/sglang/jit_kernel/diffusion/triton/group_norm_silu.py JIT 内核 modified 3.78

关键符号

apply_group_norm_silu _apply_hunyuan_group_norm_silu test_apply_group_norm_silu

关键源码片段

python/sglang/jit_kernel/diffusion/group_norm_silu.py core-logic

新增的通用辅助函数,核心逻辑所在

import torch
from torch import nndef apply_group_norm_silu(
    x: torch.Tensor,
    norm: nn.Module,
    activation: nn.Module,
) -> torch.Tensor:
    # 仅当输入在 CUDA 上、梯度未启用、且 norm 和 activation 符合条件时,才使用 Triton 融合内核
    if (
        x.is_cuda
        and not torch.is_grad_enabled()
        and not x.requires_grad
        and isinstance(norm, nn.GroupNorm)
        and isinstance(activation, nn.SiLU)
        and not activation.inplace
        and norm.affine
        and norm.weight is not None
        and norm.bias is not None
    ):
        # 延迟导入,避免模块初始化开销(但会在热路径上每次调用进行模块查找)
        from sglang.jit_kernel.diffusion.triton.group_norm_silu import (
            triton_group_norm_silu,
        )
​
        return triton_group_norm_silu(
            x,
            norm.weight,
            norm.bias,
            num_groups=norm.num_groups,
            eps=norm.eps,
        )
    # 对于不支持的配置(CPU、梯度模式、非 SiLU、非 affine 等),回退到 PyTorch 原生路径
    return activation(norm(x))
python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py data-contract

模型侧集成,替换原有专用函数

# ResBlock 的 forward 方法中,使用通用 apply_group_norm_silu 替换原有专用调用
class HunyuanVideoResBlock3D(nn.Module):
    # ...
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states.contiguous()
        residual = hidden_states
​
        # 原为 _apply_hunyuan_group_norm_silu,现统一使用 apply_group_norm_silu
        hidden_states = apply_group_norm_silu(
            hidden_states, self.norm1, self.nonlinearity
        )
        hidden_states = self.conv1(hidden_states)
​
        hidden_states = apply_group_norm_silu(
            hidden_states, self.norm2, self.nonlinearity
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)
​
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
​
        hidden_states = hidden_states + residual
        return hidden_states

评论区精华

延迟导入性能优化 性能

gemini-code-assist[bot] 建议将函数内的延迟导入移到模块顶层,以避免热路径上的符号查找开销。yuan-luo 表示同意此观点。

结论:未采纳,代码仍使用延迟导入,但合并后其余部分得到认可。 · unresolved

风险与影响

主要风险:移除环境变量后用户无法手动禁用融合路径,但自动 fallback 机制保证了不支持的配置(如 CPU、梯度模式、非 affine)仍使用原生路径;延迟导入在每次调用时都有符号查找开销,但通常影响微小;Triton 内核依赖 CUDA,非 CUDA 设备自动降级。

对 HunyuanVideo VAE 推理用户有显著性能收益,解码耗时降低约 17%,无需手动配置。团队移除了环境变量,降低了配置复杂度。

移除环境变量后无法显式关闭融合 延迟导入在热路径上的微小开销

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论