执行摘要
- 一句话:修复LTX-2/2.3扩散模型多GPU精度对齐与HQ两阶段路径
- 推荐动作:建议所有使用LTX-2/2.3模型进行推理和CI测试的团队仔细阅读此PR。其中关于CFG引导分支合并、Attention Backend选择、RoPE精度控制的决策值得在其它扩散模型推理框架中借鉴。
功能与动机
根据PR描述,主要动机是提升LTX-2 / LTX-2.3扩散模型与官方输出的精度对齐,尤其是多GPU CI用例和LTX-2.3 HQ两阶段路径。之前的LTX-2.3多GPU CI设置对某些用例的一致性阈值过松,且使用了较慢或对齐性较差的并行模式;HQ路径在PR#23366的SpongeBob复现中出现了回归。
实现拆解
- 重构CFG引导分支合并逻辑:引入
_ltx2_combine_guided_x0_parallel_av 方法替代旧的 _ltx2_combine_guided_x0_parallel,将视频和音频引导分支先分别计算为x0后再合并,通过一次all-reduce同步所有分支,避免折叠系数导致的bf16舍入差异。
- 修复RoPE频率生成:移除基于CPU/NumPy的缓存函数
_ltx2_rope_freq_grid_np,改为在目标设备上使用指定精度(float32/float64)直接生成频率,保留与官方一致的舍入轨迹;Gemma3编码器也改为使用预计算的 cos_sin_cache 并通过 index_select 查找,同时修复滑动注意力模式的检测逻辑(支持 sliding_window_pattern)。
- 统一Attention Backend:让
transformer_2 等次级组件继承基础组件的attention backend,允许在 LocalAttention 和 USPAttention 中使用cuDNN SDP后端(通过 allow_cudnn_sdp 参数),以匹配官方LTX的 torch_sdpa 行为;同时处理旧版FA3 varlen kernel不支持 out= 关键字的情况。
- 对齐LTX-2.3 HQ路径:恢复阶段1和阶段2的res2s噪声精度(保持float64轨迹),修复sigma空间数学计算,并确保HQ变长序列的CFG广播源使用全局rank。
- 收紧CI测试:调整多GPU用例的并行策略(TP/CFG Parallel取代SP/Ulysses),收紧一致性阈值(clip/SSIM/PSNR/mean_abs_diff),添加新用例到官方一致性GT集合,并改进测试失败时的HTML报告,包含生成图像链接。
关键文件:
python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py(模块 去噪流水线;类别 source;类型 core-logic;符号 _ltx2_combine_guided_x0_parallel, _ltx2_combine_guided_x0_parallel_av, _move_ltx2_scheduler_tensors_to_device): 核心逻辑变更:重构CFG引导分支合并,移除旧方法,引入新方法_ltx2_combine_guided_x0_parallel_av,支持视频/音频分离预处理后再合并,改变舍入路径
python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py(模块 文本编码器;类别 source;类型 data-contract;符号 rotary_emb, _apply_rotary_pos_emb): Gemma3编码器RoPE和滑动注意力修复:改用预计算缓存,支持sliding_window_pattern,提升精度
python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py(模块 扩散Transformer;类别 source;类型 data-contract;符号 _ltx2_rope_freq_grid_np): LTX DiT模型RoPE频率生成方式变更,移除NumPy缓存,改用设备端生成;attention后端开放cuDNN SDP
python/sglang/multimodal_gen/runtime/layers/attention/layer.py(模块 注意力层;类别 source;类型 dependency-wiring): Attention层增加allow_cudnn_sdp参数,控制sdpa_kernel上下文,使LTX可以使用cuDNN加速
python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py(模块 SDPA后端;类别 source;类型 dependency-wiring): SDPA后端同样增加allow_cudnn_sdp参数,使官方LTX的torch_sdpa路径可选择cuDNN后端
python/sglang/multimodal_gen/test/test_utils.py(模块 测试工具;类别 test;类型 test-coverage;符号 _save_generated_artifact_images): 测试工具增强:添加_save_generated_artifact_images函数保存生成帧图像,并在HTML报告中添加链接,便于调试
python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py(模块 去噪流水线;类别 source;类型 core-logic): 伴随ltx_2_denoising.py的接口调整,更新配置键
python/sglang/multimodal_gen/runtime/server_args.py(模块 服务参数;类别 source;类型 core-logic): 量化配置优先级调整:显式--quantization优先于检查点metadata
python/sglang/multimodal_gen/test/server/perf_baselines.json(模块 性能基线;类别 test;类型 test-coverage): 更新多GPU LTX性能基线,匹配新的并行策略和精度对齐
python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py(模块 并行执行器;类别 source;类型 core-logic): 修复CFG并行阶段的广播源:使用全局source rank而非本地rank
docs_new/src/snippets/diffusion/ltx-deployment.jsx(模块 部署文档;类别 source;类型 core-logic): 文档示例更新,反映新的部署配置
python/sglang/multimodal_gen/test/server/consistency_threshold.json(模块 一致性阈值;类别 test;类型 test-coverage): 收紧多个LTX用例的一致性阈值,提升CI质量门禁
关键符号:_ltx2_combine_guided_x0_parallel_av, _move_ltx2_scheduler_tensors_to_device, _apply_rotary_pos_emb, _ltx2_rope_freq_grid_np, apply_split_rotary_emb, _save_generated_artifact_images
关键源码片段
python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py
核心逻辑变更:重构CFG引导分支合并,移除旧方法,引入新方法_ltx2_combine_guided_x0_parallel_av,支持视频/音频分离预处理后再合并,改变舍入路径
# python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py
@classmethod
def _ltx2_combine_guided_x0_parallel_av(
cls,
*,
video_latents: torch.Tensor,
audio_latents: torch.Tensor,
local_video_velocities: dict[str, torch.Tensor],
local_audio_velocities: dict[str, torch.Tensor],
video_sigma: float | torch.Tensor,
audio_sigma: float | torch.Tensor,
video_cfg_scale: float,
video_stg_scale: float,
video_rescale_scale: float,
video_modality_scale: float,
audio_cfg_scale: float,
audio_stg_scale: float,
audio_rescale_scale: float,
audio_modality_scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
在CFG并行中,跨rank重建完整的引导分支,然后分别计算视频和音频的guided x0。
不再合并系数后all-reduce,而是先将每个分支的x0通过all-reduce同步,
再使用官方标准的组合公式,以消除bf16下系数折叠导致的数值漂移。
"""
# 获取第一个 velocity 来构造模板张量(用于 zero 填充)
first_video_velocity = next(iter(local_video_velocities.values()))
first_audio_velocity = next(iter(local_audio_velocities.values()))
video_template = cls._ltx2_velocity_to_x0(
video_latents, first_video_velocity, video_sigma
)
audio_template = cls._ltx2_velocity_to_x0(
audio_latents, first_audio_velocity, audio_sigma
)
video_numel = video_template.numel()
# 对 4 个分支(cond, neg, perturbed, modality)分别收集并 all-reduce
branches: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
for name in ("cond", "neg", "perturbed", "modality"):
if name in local_video_velocities:
# 该 rank 拥有此分支的真实输出
local_video = cls._ltx2_velocity_to_x0(
video_latents, local_video_velocities[name], video_sigma
)
local_audio = cls._ltx2_velocity_to_x0(
audio_latents, local_audio_velocities[name], audio_sigma
)
else:
# 该 rank 不负责此分支,用零填充
local_video = torch.zeros_like(video_template)
local_audio = torch.zeros_like(audio_template)
# 拼接后 all-reduce,使每个 rank 都获得完整的分支 x0
flat = torch.cat((local_video.reshape(-1), local_audio.reshape(-1)))
flat = cfg_model_parallel_all_reduce(flat)
branches[name] = (
flat[:video_numel].reshape_as(video_template),
flat[video_numel:].reshape_as(audio_template),
)
# 分别计算视频和音频的 guided x0(使用官方组合公式,不折叠系数)
guided_video = cls._ltx2_calculate_guided_x0(
cond=branches["cond"][0],
uncond_text=branches["neg"][0],
uncond_perturbed=branches["perturbed"][0],
uncond_modality=branches["modality"][0],
cfg_scale=video_cfg_scale,
stg_scale=video_stg_scale,
rescale_scale=video_rescale_scale,
modality_scale=video_modality_scale,
)
guided_audio = cls._ltx2_calculate_guided_x0(
cond=branches["cond"][1],
uncond_text=branches["neg"][1],
uncond_perturbed=branches["perturbed"][1],
uncond_modality=branches["modality"][1],
cfg_scale=audio_cfg_scale,
stg_scale=audio_stg_scale,
rescale_scale=audio_rescale_scale,
modality_scale=audio_modality_scale,
)
return guided_video, guided_audio
python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py
Gemma3编码器RoPE和滑动注意力修复:改用预计算缓存,支持sliding_window_pattern,提升精度
# python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py
# 在 __init__ 中:
# 之前的代码直接从 layer_types 列表索引,现在兼容 sliding_window_pattern 配置
sliding_window_pattern = getattr(
config.text_config, "sliding_window_pattern", None
)
self.is_sliding = (
bool((layer_id + 1) % sliding_window_pattern)
if sliding_window_pattern
else False
)
self.layer_type = "sliding_attention" if self.is_sliding else None
# RoPE 初始化时,将 self.rotary_emb 重命名为 self.rotary_pos_emb
self.rotary_pos_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.text_config.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
is_neox_style=True,
)
# 新增方法 _apply_rotary_pos_emb,使用预计算 cos_sin_cache,避免每步重新计算 inv_freq
def _apply_rotary_pos_emb(self, positions, q, k):
positions_flat = positions.flatten().to(
device=self.rotary_pos_emb.cos_sin_cache.device, dtype=torch.long
)
cos_sin = self.rotary_pos_emb.cos_sin_cache.index_select(0, positions_flat)
cos, sin = cos_sin.chunk(2, dim=-1)
# 扩展半维度频率以匹配 head_dim(HF Gemma3 风格)
cos = torch.cat((cos, cos), dim=-1).to(device=q.device, dtype=q.dtype)
sin = torch.cat((sin, sin), dim=-1).to(device=q.device, dtype=q.dtype)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
# 应用旋转
q = q.reshape(num_tokens, -1, self.head_dim)
k = k.reshape(num_tokens, -1, self.head_dim)
q = q * cos + _rotate_half(q) * sin
k = k * cos + _rotate_half(k) * sin
return q, k
python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py
LTX DiT模型RoPE频率生成方式变更,移除NumPy缓存,改用设备端生成;attention后端开放cuDNN SDP
# python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py
# 在 RoPE 频率生成部分(原使用缓存函数 _ltx2_rope_freq_grid_np):
# 新实现:直接在目标设备上生成频率,保留 float64 精度路径
freqs_dtype = torch.float64 if self.double_precision else torch.float32
pow_indices = torch.pow(
self.theta,
torch.linspace(
start=0.0,
end=1.0,
steps=self.dim // num_rope_elems,
dtype=freqs_dtype,
device=device,
),
)
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
# ... 后续与 grid 组合
# 在 TransformerBlock 初始化中,允许 cuDNN SDP 后端以匹配官方 LTX 的 torch_sdpa 行为
if use_local:
self.attn = LocalAttention(
# ... 其他参数
allow_cudnn_sdp=True,
)
评论区精华
本PR没有实质性的review讨论,主要作者独立完成调试和验证。PR body中明确排除了transformer fp8-cast,指出该PR的目的路径(CI和HQ精度对齐)无需此变更,为后续独立分支留出空间。
风险与影响
- 风险:
- CFG分支合并重构:新方法改变了all-reduce和系数组合顺序,可能影响单GPU和多GPU结果的数值一致性,已通过官方输出PSNR验证(20.722 vs 20.707)。
- RoPE精度切换:从float64 NumPy改为设备端float32/float64生成,可能改变所有LTX模型的生成轨迹,但已在多GPU CI中验证。
- Attention Backend:默认启用cuDNN SDP可能在非CUDA设备上回退,且未覆盖所有后端(如FlashInfer),需确保FA3 varlen无out关键字的兼容处理。
- CI阈值收紧:新阈值更严格,可能导致后续升级引入波动,但提升了质量保障。
- 影响:对用户:LTX-2/2.3模型的生成结果更稳定、与官方对齐度更高,特别是在多GPU和HQ两阶段场景下。对系统:引入少量性能优化(如CFG Parallel比Ulysses更快),但主要改进是数值精度。对团队:为后续LTX精度优化提供了清晰的调试路径和CI基准,并展示了跨文件精细调优的工程实践。
- 风险标记:核心路径变更, 精度敏感逻辑, 多GPU一致性, CI阈值收紧
关联脉络
- PR #23366 [LTX-2.3] SpongeBob repro for HQ precision: PR#23366是LTX-2.3 HQ精度对齐的原始复现用例,本PR在body中引用了该PR的repro结果来验证HQ路径的修复效果
参与讨论