Prhub

#24313 [diffusion] chore: align LTX-2 with official

原始 PR 作者 mickqian 合并时间 2026-05-07 08:46 文件变更 11 提交数 74 评论 2 代码增减 +369 / -208

执行摘要

对齐 LTX-2 与官方实现的注意力语义和数值精度

Align native LTX text-encoder attention behavior with the official implementation while preserving high-performance attention backends outside the text encoder path. Keep CI consistency gates honest by using official GT only for cases whose request semantics are currently comparable.

建议精读以下部分:

  • Gemma3 注意力掩码和 GQA 处理方式的变更(gemma_3.py
  • NumPy 双精度 RoPE 频率计算的实现(ltx_2.py / ltx_2_connector.py
  • res2s 标量精度对齐策略(ltx_2_denoising.py
  • 组件级注意力后端自动配置(server_args.py

这些变更体现了将非标准注意力路径与官方逐位对齐的典型方法,值得扩散模型开发者参考。

讨论亮点

gemini-code-assist[bot] 的 auto-review 指出关键风险:in-place tensor 修改可能破坏全局 sigma 调度或引起调用方副作用,并建议利用原生 SDPA 支持优化 GQA 而非手动扩展。PR 中已通过将 SDPA 范围限定在 text_encoder、采用 scalar 精度函数而非 in-place 修改等方式缓解了部分风险。

实现拆解

  1. Gemma3 text encoder 注意力对齐:在 python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py 中,将 attention mask 从 additive bf16 mask 改为 bool keep-mask,并显式 repeat K/V 处理 GQA 而非依赖 enable_gqa=True;RoPE 计算从预计算 buffer 改为设备端实时生成,以匹配 LTX 方式。

  2. RoPE 频率计算对齐:在 python/sglang/multimodal_gen/runtime/models/dits/ltx_2.pypython/sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py 中新增 _ltx2_rope_freq_grid_np_ltx2_connector_rope_freq_grid_np 函数,利用 NumPy float64 生成双精度频率网格,并使用 functools.lru_cache 缓存结果;在 double_precision 分支中替换原有 torch 计算路径。

  3. res2s 调度标量精度对齐:在 python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py 中新增 _ltx2_phi_scalar_ltx2_get_res2s_coefficients_scalar_ltx2_res2s_step_size_scalar 函数,以标量精度计算 SDE 系数,避免张量运算中的精度损失;同时调整 _ltx2_get_sde_coeff 中的 NaN 处理逻辑。

  4. 自动设置 text_encoder 后端:在 python/sglang/multimodal_gen/runtime/server_args.py 中,当 pipeline 为 LTX2 且后端非 DIFFUSERS 时,强制将 component_attention_backends['text_encoder'] 设为 torch_sdpa,并记录日志。

  5. CI 工作流与 GT 管理:在 .github/workflows/diffusion-ci-gt-gen.yml 中扩展官方 GT 生成组,新增 ltx 组覆盖更多 case,更新 ci-data 引用支持指定分支,并调整 sparse checkout 以包含 repro 脚本;在 consistency_threshold.json 中调整 LTX-2.0 SSIM 阈值至 0.89。

文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py 去噪阶段 modified 8.85
python/sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py 适配器 modified 7.6
python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py DiT 主干 modified 7.47
python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py 文本编码器 modified 7.16
python/sglang/multimodal_gen/runtime/server_args.py 启动配置 modified 6.12
.github/workflows/diffusion-ci-gt-gen.yml CI 工作流 modified 5.9
python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py 管道 modified 5.71
python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py 去噪阶段 modified 5.66
python/sglang/multimodal_gen/test/test_utils.py 测试工具 modified 5.53
python/sglang/multimodal_gen/configs/sample/ltx_2.py 配置样本 modified 4.58
python/sglang/multimodal_gen/test/server/consistency_threshold.json 一致性阈值 modified 4.45

关键符号

_ltx2_phi_scalar _ltx2_get_res2s_coefficients_scalar _ltx2_res2s_step_size_scalar _ltx2_connector_rope_freq_grid_np _ltx2_rope_freq_grid_np Gemma3Attention.rotary_emb Gemma3Attention.forward

关键源码片段

python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py core-logic

核心 denoising stage,新增标量精度的 res2s 系数计算函数,是数值对齐的关键。

# ltx_2_denoising.py 新增标量精度辅助函数import math@staticmethod
def _ltx2_phi_scalar(j: int, neg_h: float) -> float:
    # 计算标量版本的 phi 函数,避免张量运算中的舍入误差
    if abs(neg_h) < 1e-10:
        return 1.0 / math.factorial(j)
    remainder = sum(neg_h**k / math.factorial(k) for k in range(j))
    return (math.exp(neg_h) - remainder) / (neg_h**j)@classmethod
def _ltx2_get_res2s_coefficients_scalar(
    cls, h: float, c2: float = 0.5
) -> tuple[float, float, float]:
    # 标量版本的 res2s 系数计算,与官方实现一致
    a21 = c2 * cls._ltx2_phi_scalar(1, -h * c2)
    b2 = cls._ltx2_phi_scalar(2, -h) / c2
    b1 = cls._ltx2_phi_scalar(1, -h) - b2
    return a21, b1, b2@staticmethod
def _ltx2_res2s_step_size_scalar(
    sigma: torch.Tensor, sigma_next: torch.Tensor
) -> float:
    # 从张量中提取标量步长,保持高精度
    return float(
        (
            -torch.log(
                sigma_next.detach().double().cpu() / sigma.detach().double().cpu()
            )
        ).item()
    )
python/sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py data-contract

涉及 connector 中 RoPE 频率计算的精度提升,新增缓存函数。

# ltx_2_connector.py 新增双精度 RoPE 频率计算
import functools
import numpy as np@functools.lru_cache(maxsize=5)
def _ltx2_connector_rope_freq_grid_np(
    theta: float, num_pos_dims: int, dim: int
) -> torch.Tensor:
    # Official LTX uses NumPy float64 for double-precision RoPE frequencies.
    n_elem = 2 * num_pos_dims
    pow_indices = np.power(
        theta,
        np.linspace(0.0, 1.0, dim // n_elem, dtype=np.float64),
    )
    return torch.tensor(pow_indices * math.pi / 2.0, dtype=torch.float32)

评论区精华

In-place tensor operations risk 正确性

gemini-code-assist[bot] 指出 in-place 修改可能污染全局 sigma 调度或引起副作用

结论:开发者通过 scalar 函数和 NaN 处理避免 in-place,保留了原始路径 · 已解决

SDPA GQA handling 性能

建议利用原生 SDPA GQA 支持,减少手动 repeat 开销

结论:开发者选择显式 repeat 以匹配官方行为,牺牲些许性能确保语义对齐 · 已解决

风险与影响

  1. Gemma3 注意力路径变更风险:在 gemma_3.py 中,attention mask 从 additive 改为 bool、GQA 从 enable_gqa 改为显式 repeat,虽然与官方对齐,但可能影响其他非 LTX 场景下的 Gemma3 编码器行为。由于该文件是 LTX-2 专有编码器,风险可控。
  2. RoPE 精度变更影响_ltx2_rope_freq_grid_np 等函数改用 NumPy float64,在 double_precision 模式下可能引入微小数值差异,但已通过缓存和类型转换确保一致性。
  3. res2s 调度精度变更:新增的标量函数仅用于 HQ 等特定 case,且保留了原始张量路径,不影响主流。
  4. CI 工作流配置错误风险diffusion-ci-gt-gen.yml 中新增的 ltx 组和 ci_data_ref 参数可能因权限或路径问题导致 GT 生成失败,但已有 fallback 机制(GET fallback for HEAD check)。

用户影响:对使用 LTX-2.0/LTX-2.3 模型的最终用户,text encoder 输出和生成质量更接近官方,但无 breaking change;性能方面,torch_sdpa 仅用于 text_encoder,DiT 和 connector 仍可保持高性能后端。
系统影响:CI 一致性测试更可靠,官方 GT 覆盖更多 case,阈值更严格,减少了误报。
团队影响:为后续扩散模型精度对齐工作提供了可复用的模式(NumPy 双精度 RoPE 缓存、scalar 精度辅助函数)。

核心路径变更 数值精度敏感 CI 配置复杂 多文件耦合

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论