Prhub

#25674 [diffusion] Fix MOVA DAC bf16 on ROCm

原始 PR 作者 qimcis 合并时间 2026-05-22 15:18 文件变更 1 提交数 2 评论 6 代码增减 +16 / -3

执行摘要

修复 ROCm bf16 下 DAC Snake 编译失败

在 ROCm 平台上,DAC Snake 使用 torch.jit.script 编译 bf16 张量时,HIPRTC 编译失败。该问题在 PR #25411 的评论中被报告 (issuecomment-4476356483)。此 PR 旨在提供一个最小化修复,仅对 ROCm + bf16 场景回退到 eager 执行,避免 JIT 编译错误。

该 PR 值得精读,特别是对于需要支持多硬件平台(如 ROCm)的团队。其设计模式——将 JIT 编译的函数拆分为纯 Python 实现和编译赋值,并添加条件回退——是一种优雅的跨平台兼容性解决方案,值得在其他类似场景中借鉴。

讨论亮点

无争议性讨论。审核者 mickqian 直接批准了 PR,作者 qimcis 在 CI 通过后请求合并。没有 review 评论。

实现拆解

  1. 将核心 snake 计算提取为纯 Python 函数 _snake:将原 @torch.jit.script 装饰的函数体移至一个普通的 Python 函数 _snake,该函数负责实际的 Snake 激活计算。
  2. 保留 JIT 编译版本:通过 snake = torch.jit.script(_snake)_snake 编译为 TorchScript,并作为默认的快速路径。对于非 ROCm bf16 场景,性能不受影响。
  3. 新增条件判定函数 _should_use_eager_snake_on_rocm_bf16:该函数检查是否运行在 ROCm(torch.version.hip 非空)、张量是否在 GPU 上、以及数据是否为 bfloat16。
  4. Snake1d.forward 中引入分支:当条件满足时调用 _snake(x, self.alpha)(eager 模式),否则调用 snake(x, self.alpha)(JIT 加速版本)。
  5. 文件变更:仅修改了 python/sglang/multimodal_gen/runtime/models/vaes/dac.py 一个文件,增加 16 行,删除 3 行,属于对已有行为的最小侵入性修复。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/models/vaes/dac.py VAE 模型 modified 7.76

关键符号

_snake _should_use_eager_snake_on_rocm_bf16

关键源码片段

python/sglang/multimodal_gen/runtime/models/vaes/dac.py core-logic

唯一变更文件,核心修改包括拆分 snake 函数、新增条件判定逻辑、修改 Snake1d.forward 分支。

from torch import nn# 实现 Snake 激活函数的底层计算,以纯 Python 函数形式定义,
# 后续既可以用于 JIT 编译,也可以直接作为 eager 函数调用。
def _snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x# Scripting this brings model speed up 1.4x
# 将 _snake 编译为 TorchScript,保留性能加速,作为默认路径。
snake = torch.jit.script(_snake)# ROCm HIPRTC can fail to compile the scripted bf16 Snake kernel.
# 检测当前环境是否为 ROCm 且张量为 bf16,若是则触发 eager 回退。
def _should_use_eager_snake_on_rocm_bf16(x: torch.Tensor, alpha: torch.Tensor) -> bool:
    return (
        torch.version.hip is not None # 是否运行在 ROCm 平台
        and (x.is_cuda or alpha.is_cuda) # 张量是否在 GPU 上
        and (x.dtype == torch.bfloat16 or alpha.dtype == torch.bfloat16) # 是否为 bfloat16 类型
    )
​
​
class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
​
    def forward(self, x):
        # 在 ROCm + bf16 条件下使用 eager 模式,避免 JIT 编译失败
        if _should_use_eager_snake_on_rocm_bf16(x, self.alpha):
            return _snake(x, self.alpha)
        # 其他路径走 JIT 编译加速版本
        return snake(x, self.alpha)

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险较低。修改仅影响 ROCm + bf16 场景下的 Snake 激活函数执行路径,对 NVIDIA 或非 bf16 路径无影响。JIT 路径完全保留,因此性能退化仅发生在受影响的场景。测试覆盖方面,没有新增单元测试,建议未来补充针对 ROCm bf16 的回归测试。

影响范围有限,仅作用于 ROCm 平台上使用 DAC 模型且输入为 bf16 的用户。修复了此前导致推理失败或崩溃的编译错误,使 ROCm 用户能够正常使用 DAC 模型。对其他平台和数据类型无影响。

缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论