执行摘要
- 一句话:对齐 LTX-2 与官方实现的注意力语义和数值精度
- 推荐动作:建议精读以下部分:
- Gemma3 注意力掩码和 GQA 处理方式的变更(
gemma_3.py)
- NumPy 双精度 RoPE 频率计算的实现(
ltx_2.py / ltx_2_connector.py)
- res2s 标量精度对齐策略(
ltx_2_denoising.py)
- 组件级注意力后端自动配置(
server_args.py)
这些变更体现了将非标准注意力路径与官方逐位对齐的典型方法,值得扩散模型开发者参考。
功能与动机
Align native LTX text-encoder attention behavior with the official implementation while preserving high-performance attention backends outside the text encoder path. Keep CI consistency gates honest by using official GT only for cases whose request semantics are currently comparable.
实现拆解
-
Gemma3 text encoder 注意力对齐:在 python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py 中,将 attention mask 从 additive bf16 mask 改为 bool keep-mask,并显式 repeat K/V 处理 GQA 而非依赖 enable_gqa=True;RoPE 计算从预计算 buffer 改为设备端实时生成,以匹配 LTX 方式。
-
RoPE 频率计算对齐:在 python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py 和 python/sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py 中新增 _ltx2_rope_freq_grid_np 和 _ltx2_connector_rope_freq_grid_np 函数,利用 NumPy float64 生成双精度频率网格,并使用 functools.lru_cache 缓存结果;在 double_precision 分支中替换原有 torch 计算路径。
-
res2s 调度标量精度对齐:在 python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py 中新增 _ltx2_phi_scalar、_ltx2_get_res2s_coefficients_scalar 和 _ltx2_res2s_step_size_scalar 函数,以标量精度计算 SDE 系数,避免张量运算中的精度损失;同时调整 _ltx2_get_sde_coeff 中的 NaN 处理逻辑。
-
自动设置 text_encoder 后端:在 python/sglang/multimodal_gen/runtime/server_args.py 中,当 pipeline 为 LTX2 且后端非 DIFFUSERS 时,强制将 component_attention_backends['text_encoder'] 设为 torch_sdpa,并记录日志。
-
CI 工作流与 GT 管理:在 .github/workflows/diffusion-ci-gt-gen.yml 中扩展官方 GT 生成组,新增 ltx 组覆盖更多 case,更新 ci-data 引用支持指定分支,并调整 sparse checkout 以包含 repro 脚本;在 consistency_threshold.json 中调整 LTX-2.0 SSIM 阈值至 0.89。
关键文件:
python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py(模块 去噪阶段;类别 source;类型 core-logic;符号 _ltx2_phi_scalar, _ltx2_get_res2s_coefficients_scalar, _ltx2_res2s_step_size_scalar): 核心 denoising stage,新增标量精度的 res2s 系数计算函数,是数值对齐的关键。
python/sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py(模块 适配器;类别 source;类型 data-contract;符号 _ltx2_connector_rope_freq_grid_np): 涉及 connector 中 RoPE 频率计算的精度提升,新增缓存函数。
python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py(模块 DiT 主干;类别 source;类型 data-contract;符号 _ltx2_rope_freq_grid_np): DiT 主干中的 RoPE 频率计算对齐,与 connector 对称。
python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py(模块 文本编码器;类别 source;类型 data-contract): Gemma3 text encoder 注意力对齐,修改 mask 类型、GQA 处理和 RoPE 计算。
python/sglang/multimodal_gen/runtime/server_args.py(模块 启动配置;类别 source;类型 core-logic): 自动设置 text_encoder 后端为 torch_sdpa,确保对齐生效。
.github/workflows/diffusion-ci-gt-gen.yml(模块 CI 工作流;类别 infra;类型 infrastructure;符号 link): CI 工作流扩展,增加官方 GT 生成覆盖和灵活性。
python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py(模块 管道;类别 source;类型 core-logic): 管道级别调整,适配种子变化和精度对齐。
python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py(模块 去噪阶段;类别 source;类型 core-logic): 音频-视频去噪通用调整,配合精度修改。
python/sglang/multimodal_gen/test/test_utils.py(模块 测试工具;类别 test;类型 test-coverage): 测试工具更新,支持新的 GT 文件和 CI 场景。
python/sglang/multimodal_gen/configs/sample/ltx_2.py(模块 配置样本;类别 source;类型 core-logic): 示例配置更新,增加官方 GT 相关配置。
python/sglang/multimodal_gen/test/server/consistency_threshold.json(模块 一致性阈值;类别 test;类型 test-coverage): 调整 LTX-2.0 SSIM 阈值,平衡 CI 严格性。
关键符号:_ltx2_phi_scalar, _ltx2_get_res2s_coefficients_scalar, _ltx2_res2s_step_size_scalar, _ltx2_connector_rope_freq_grid_np, _ltx2_rope_freq_grid_np, Gemma3Attention.rotary_emb, Gemma3Attention.forward
关键源码片段
python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py
核心 denoising stage,新增标量精度的 res2s 系数计算函数,是数值对齐的关键。
# ltx_2_denoising.py 新增标量精度辅助函数
import math
@staticmethod
def _ltx2_phi_scalar(j: int, neg_h: float) -> float:
# 计算标量版本的 phi 函数,避免张量运算中的舍入误差
if abs(neg_h) < 1e-10:
return 1.0 / math.factorial(j)
remainder = sum(neg_h**k / math.factorial(k) for k in range(j))
return (math.exp(neg_h) - remainder) / (neg_h**j)
@classmethod
def _ltx2_get_res2s_coefficients_scalar(
cls, h: float, c2: float = 0.5
) -> tuple[float, float, float]:
# 标量版本的 res2s 系数计算,与官方实现一致
a21 = c2 * cls._ltx2_phi_scalar(1, -h * c2)
b2 = cls._ltx2_phi_scalar(2, -h) / c2
b1 = cls._ltx2_phi_scalar(1, -h) - b2
return a21, b1, b2
@staticmethod
def _ltx2_res2s_step_size_scalar(
sigma: torch.Tensor, sigma_next: torch.Tensor
) -> float:
# 从张量中提取标量步长,保持高精度
return float(
(
-torch.log(
sigma_next.detach().double().cpu() / sigma.detach().double().cpu()
)
).item()
)
python/sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py
涉及 connector 中 RoPE 频率计算的精度提升,新增缓存函数。
# ltx_2_connector.py 新增双精度 RoPE 频率计算
import functools
import numpy as np
@functools.lru_cache(maxsize=5)
def _ltx2_connector_rope_freq_grid_np(
theta: float, num_pos_dims: int, dim: int
) -> torch.Tensor:
# Official LTX uses NumPy float64 for double-precision RoPE frequencies.
n_elem = 2 * num_pos_dims
pow_indices = np.power(
theta,
np.linspace(0.0, 1.0, dim // n_elem, dtype=np.float64),
)
return torch.tensor(pow_indices * math.pi / 2.0, dtype=torch.float32)
评论区精华
gemini-code-assist[bot] 的 auto-review 指出关键风险:in-place tensor 修改可能破坏全局 sigma 调度或引起调用方副作用,并建议利用原生 SDPA 支持优化 GQA 而非手动扩展。PR 中已通过将 SDPA 范围限定在 text_encoder、采用 scalar 精度函数而非 in-place 修改等方式缓解了部分风险。
- In-place tensor operations risk (correctness): 开发者通过 scalar 函数和 NaN 处理避免 in-place,保留了原始路径
- SDPA GQA handling (performance): 开发者选择显式 repeat 以匹配官方行为,牺牲些许性能确保语义对齐
风险与影响
关联脉络
- PR #24320 [Misc] component attention backend override support: PR #24313 的 body 明确提到 merge of #24320,后者提供了组件级别 attention backend 覆盖机制,是实现 text_encoder 单独使用 torch_sdpa 的基础。
- PR #23335 Fix diffusion fallback guards and validation: 同为 diffusion 模块的数值对齐相关工作,涉及 fallback 路径。
- PR #24117 [codex] Optimize Z-Image packed QKV: 同一模块(multimodal_gen)的性能优化 PR,与对齐工作互补。
参与讨论