Prhub

#24411 [diffusion] Fuse LTX2 split rotary embedding

原始 PR 作者 mickqian 合并时间 2026-05-05 16:07 文件变更 2 提交数 2 评论 3 代码增减 +108 / -0

执行摘要

新增 LTX2 融合 Triton 内核,BF16 路径性能提升约 1%

LTX2 模型的 DiT 路径中使用了一种特殊的拆分旋转嵌入(split rotary embedding)布局,之前没有针对该布局的融合内核,使用 PyTorch 多次 reshape/broadcast/elementwise 操作存在性能开销。本 PR 通过添加专用 Triton 内核,在保持数值一致性的前提下提升性能。

值得阅读,尤其是学习如何在现有 PyTorch 操作中安全插入融合内核的策略。该 PR 展示了保证数值一致性的方法(匹配 BF16 舍入顺序)和条件回退设计,可作为后续类似优化的参考。

讨论亮点

PR 未触发人工 review 评论,仅包含自动化 bot 的配额提醒和作者触发的 CI 重新运行指令。无争议或设计讨论。

实现拆解

  1. 新增 Triton 内核:在 python/sglang/jit_kernel/diffusion/triton/ltx2_rotary.py 中实现 _ltx2_split_rotary_kernel。该内核采用 2D 网格([batch*seq_len, num_heads]),内部按照 x 被拆分为两个半头的布局加载 x_firstx_second,加载对应的 cossin,计算旋转公式时严格匹配原始 PyTorch 中 x*cos 先转为 BF16 再参与 addcmul_ 的顺序,保证位精确。

  2. 包装函数:编写 apply_ltx2_split_rotary_emb,进行形状校验后分配输出张量,设置 BLOCK_HALFtriton.next_power_of_2(half_dim),启动内核(num_warps=1)。

  3. 路由条件:在 python/sglang/multimodal_gen/runtime/models/dits/ltx_2.pyapply_split_rotary_emb 函数入口插入条件判断:当 x 为 3D、cos/sin 为 4D、数据类型均为 BF16,且所有张量在 CUDA 上且连续时,转到新内核;否则回退到原有的 PyTorch 实现。导入语句放在条件内以避免全局 import 开销。

  4. 测试与验证:未新增显式测试文件,但通过远程内核检查验证了两种目标形状的位精确性(max_abs=0.0),且端到端 benchmark 显示了可测量的性能提升。

文件 模块 状态 重要度
python/sglang/jit_kernel/diffusion/triton/ltx2_rotary.py JIT 内核 added 8.09
python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py 模型路由 modified 6.48

关键符号

_ltx2_split_rotary_kernel apply_ltx2_split_rotary_emb apply_split_rotary_emb

关键源码片段

python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py data-contract

修改现有 apply_split_rotary_emb 函数,添加条件路由到新内核。这是将融合内核接入模型推理的关键入口。

def apply_split_rotary_emb(
    x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
    cos, sin = freqs
​
    # 当输入满足特定条件时,使用 LTX2 专用融合内核
    if (
        x.ndim == 3
        and cos.ndim == 4
        and sin.ndim == 4
        and x.dtype == torch.bfloat16
        and cos.dtype == torch.bfloat16
        and sin.dtype == torch.bfloat16
        and x.is_cuda
        and x.is_contiguous()
        and cos.is_cuda
        and sin.is_cuda
    ):
        # 延迟导入以避免全局依赖
        from sglang.jit_kernel.diffusion.triton.ltx2_rotary import (
            apply_ltx2_split_rotary_emb,
        )
​
        return apply_ltx2_split_rotary_emb(x, cos, sin)
​
    # 原有 PyTorch 实现(作为 fallback)
    x_dtype = x.dtype
    needs_reshape = False
    if x.ndim != 4 and cos.ndim == 4:
        b = x.shape[0]
        _, h, t, _ = cos.shape
        x = x.reshape(b, t, h, -1).swapaxes(1, 2)
        needs_reshape = True
​
    last = x.shape[-1]
    if last % 2 != 0:
        raise ValueError(
            f"Expected x.shape[-1] to be even for split rotary, got {last}."
        )
    r = last // 2
​
    split_x = x.reshape(*x.shape[:-1], 2, r)
    first_x = split_x[..., :1, :]
    second_x = split_x[..., 1:, :]
​
    cos_u = cos.unsqueeze(-2)
    sin_u = sin.unsqueeze(-2)
​
    out = split_x * cos_u
    first_out = out[..., :1, :]
    second_out = out[..., 1:, :]
    first_out.addcmul_(-sin_u, second_x)
    second_out.addcmul_(sin_u, first_x)
​
    out = out.reshape(*out.shape[:-2], last)
    if needs_reshape:
        out = out.swapaxes(1, 2).reshape(b, t, -1)
​
    return out

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险较低。新增内核仅替换特定形状和精度下的计算,其余路径完全不变;数值上通过位精确验证;性能提升幅度小且只影响 LTX2 推理阶段。主要风险是新内核在极端形状或非标准 head_dim 下可能未充分测试,但由于条件路由严格,不会错误影响其他情况。

对用户:LTX2 模型推理性能轻微提升(约1%),无功能或接口变化。对系统:新增一个约90行的内核文件,对现有代码影响仅限于一个函数内的条件分支。对团队:需要维护新增的 Triton 内核。

缺少测试覆盖 精度对齐要求

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论