执行摘要
- 一句话:新增 LTX2 融合 Triton 内核,BF16 路径性能提升约 1%
- 推荐动作:值得阅读,尤其是学习如何在现有 PyTorch 操作中安全插入融合内核的策略。该 PR 展示了保证数值一致性的方法(匹配 BF16 舍入顺序)和条件回退设计,可作为后续类似优化的参考。
功能与动机
LTX2 模型的 DiT 路径中使用了一种特殊的拆分旋转嵌入(split rotary embedding)布局,之前没有针对该布局的融合内核,使用 PyTorch 多次 reshape/broadcast/elementwise 操作存在性能开销。本 PR 通过添加专用 Triton 内核,在保持数值一致性的前提下提升性能。
实现拆解
-
新增 Triton 内核:在 python/sglang/jit_kernel/diffusion/triton/ltx2_rotary.py 中实现 _ltx2_split_rotary_kernel。该内核采用 2D 网格([batch*seq_len, num_heads]),内部按照 x 被拆分为两个半头的布局加载 x_first 和 x_second,加载对应的 cos 和 sin,计算旋转公式时严格匹配原始 PyTorch 中 x*cos 先转为 BF16 再参与 addcmul_ 的顺序,保证位精确。
-
包装函数:编写 apply_ltx2_split_rotary_emb,进行形状校验后分配输出张量,设置 BLOCK_HALF 为 triton.next_power_of_2(half_dim),启动内核(num_warps=1)。
-
路由条件:在 python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py 的 apply_split_rotary_emb 函数入口插入条件判断:当 x 为 3D、cos/sin 为 4D、数据类型均为 BF16,且所有张量在 CUDA 上且连续时,转到新内核;否则回退到原有的 PyTorch 实现。导入语句放在条件内以避免全局 import 开销。
-
测试与验证:未新增显式测试文件,但通过远程内核检查验证了两种目标形状的位精确性(max_abs=0.0),且端到端 benchmark 显示了可测量的性能提升。
关键文件:
python/sglang/jit_kernel/diffusion/triton/ltx2_rotary.py(模块 JIT 内核;类别 source;类型 dependency-wiring;符号 _ltx2_split_rotary_kernel, apply_ltx2_split_rotary_emb): 本 PR 核心新增的 Triton 融合内核文件,实现了 LTX2 拆分旋转嵌入的高效计算,包含内核函数和包装入口,是性能提升的关键。
python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py(模块 模型路由;类别 source;类型 data-contract;符号 apply_split_rotary_emb): 修改现有 apply_split_rotary_emb 函数,添加条件路由到新内核。这是将融合内核接入模型推理的关键入口。
关键符号:_ltx2_split_rotary_kernel, apply_ltx2_split_rotary_emb, apply_split_rotary_emb
关键源码片段
python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py
修改现有 apply_split_rotary_emb 函数,添加条件路由到新内核。这是将融合内核接入模型推理的关键入口。
def apply_split_rotary_emb(
x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
cos, sin = freqs
# 当输入满足特定条件时,使用 LTX2 专用融合内核
if (
x.ndim == 3
and cos.ndim == 4
and sin.ndim == 4
and x.dtype == torch.bfloat16
and cos.dtype == torch.bfloat16
and sin.dtype == torch.bfloat16
and x.is_cuda
and x.is_contiguous()
and cos.is_cuda
and sin.is_cuda
):
# 延迟导入以避免全局依赖
from sglang.jit_kernel.diffusion.triton.ltx2_rotary import (
apply_ltx2_split_rotary_emb,
)
return apply_ltx2_split_rotary_emb(x, cos, sin)
# 原有 PyTorch 实现(作为 fallback)
x_dtype = x.dtype
needs_reshape = False
if x.ndim != 4 and cos.ndim == 4:
b = x.shape[0]
_, h, t, _ = cos.shape
x = x.reshape(b, t, h, -1).swapaxes(1, 2)
needs_reshape = True
last = x.shape[-1]
if last % 2 != 0:
raise ValueError(
f"Expected x.shape[-1] to be even for split rotary, got {last}."
)
r = last // 2
split_x = x.reshape(*x.shape[:-1], 2, r)
first_x = split_x[..., :1, :]
second_x = split_x[..., 1:, :]
cos_u = cos.unsqueeze(-2)
sin_u = sin.unsqueeze(-2)
out = split_x * cos_u
first_out = out[..., :1, :]
second_out = out[..., 1:, :]
first_out.addcmul_(-sin_u, second_x)
second_out.addcmul_(sin_u, first_x)
out = out.reshape(*out.shape[:-2], last)
if needs_reshape:
out = out.swapaxes(1, 2).reshape(b, t, -1)
return out
评论区精华
PR 未触发人工 review 评论,仅包含自动化 bot 的配额提醒和作者触发的 CI 重新运行指令。无争议或设计讨论。
风险与影响
- 风险:风险较低。新增内核仅替换特定形状和精度下的计算,其余路径完全不变;数值上通过位精确验证;性能提升幅度小且只影响 LTX2 推理阶段。主要风险是新内核在极端形状或非标准 head_dim 下可能未充分测试,但由于条件路由严格,不会错误影响其他情况。
- 影响:对用户:LTX2 模型推理性能轻微提升(约1%),无功能或接口变化。对系统:新增一个约90行的内核文件,对现有代码影响仅限于一个函数内的条件分支。对团队:需要维护新增的 Triton 内核。
- 风险标记:缺少测试覆盖, 精度对齐要求
关联脉络
参与讨论