执行摘要
- 一句话:LoRA torch-native 后端融合 mm+add_ 为 addmm
- 推荐动作:值得精读。该 PR 展示了典型的 PyTorch 算子融合优化模式(将多次 CUDA 内核调用合并为一次
addmm),是性能调优的经典案例。同时,.item() 使用和 CPU 张量预分配的做法值得借鉴。
功能与动机
当 LoRA 数量较少(4-8 个)且输入为长序列(prefill/embedding)时,torch-native 后端性能优于 csgmv。PR body 指出通过融合 mm 与缩放/累加操作可以进一步提升 torch-native 后端的性能。
实现拆解
- 新增 CPU 端缩放张量:在
TorchNativeLoRABatchInfo 中添加 scalings_cpu 字段(CPU 设备),并在 prepare_lora_batch 中赋值。
- sgemm_lora_a_fwd 融合:将
torch.mm(x_seq, w_seq.T) 与 scaling_tensor[lora_idx] * result 合并为 torch.addmm(out_slice, x_seq, w_seq.T, beta=0, alpha=scaling_tensor[lora_idx].item(), out=out_slice)。
- sgemm_lora_b_fwd 融合:将
torch.mm(x_slice, w_slice.T) 与 output[...].add_(...) 合并为 torch.addmm(out_slice, x_slice, w_slice.T, beta=1, alpha=1, out=out_slice)。
- 调用点更新:所有调用
sgemm_lora_a_fwd 的地方(lora_a、qkv_lora、gate_up_lora)将 scaling_tensor 参数改为 scalings_cpu。
关键文件:
python/sglang/srt/lora/backend/torch_backend.py(模块 LoRA后端;类别 source;类型 core-logic;符号 TorchNativeLoRABatchInfo, TorchNativeLoRABackend, prepare_lora_batch): 修改了 LoRA batch info 数据结构,新增 scalings_cpu 字段,并更新所有调用 sgemm_lora_a_fwd 的地方以传递新的 CPU 缩放张量。
python/sglang/srt/lora/torch_ops/lora_ops.py(模块 LoRA算子;类别 infra;类型 infrastructure;符号 sgemm_lora_a_fwd, sgemm_lora_b_fwd): 修改了 sgemm_lora_a_fwd 和 sgemm_lora_b_fwd 两个核心函数,将独立的 torch.mm + 缩放/累加操作替换为单个 torch.addmm 调用,实现算子融合。
关键符号:sgemm_lora_a_fwd, sgemm_lora_b_fwd, prepare_lora_batch, TorchNativeLoRABatchInfo
关键源码片段
python/sglang/srt/lora/backend/torch_backend.py
修改了 LoRA batch info 数据结构,新增 scalings_cpu 字段,并更新所有调用 sgemm_lora_a_fwd 的地方以传递新的 CPU 缩放张量。
# python/sglang/srt/lora/backend/torch_backend.py
# 在 TorchNativeLoRABatchInfo 中新增 CPU 侧缩放张量,避免 GPU-CPU 隐式同步
@dataclass
class TorchNativeLoRABatchInfo(LoRABatchInfo):
lora_ranks_cpu: Optional[torch.Tensor] = None
seg_indptr_cpu: Optional[torch.Tensor] = None
seg_lens_cpu: Optional[torch.Tensor] = None
weight_indices_cpu: Optional[torch.Tensor] = None
# 新增 : 缩放因子张量,预先放置到 CPU 设备
scalings_cpu: Optional[torch.Tensor] = None
# 在 prepare_lora_batch 中赋值
# ... ( 原有代码 ) ...
batch_info.scalings_cpu = scalings_tensor
self.batch_info = batch_info
python/sglang/srt/lora/torch_ops/lora_ops.py
修改了 sgemm_lora_a_fwd 和 sgemm_lora_b_fwd 两个核心函数,将独立的 torch.mm + 缩放/累加操作替换为单个 torch.addmm 调用,实现算子融合。
# python/sglang/srt/lora/torch_ops/lora_ops.py
# sgemm_lora_a_fwd: 将 mm + 缩放融合为一次 addmm 调用
if rank > 0:
x_seq = inputs[token_offset : token_offset + seq_len, :]
w_seq = weights[lora_idx, : num_slices * rank, :]
out_slice = output[token_offset : token_offset + seq_len, : num_slices * rank]
# 使用 beta=0 表示不累加,alpha 为缩放因子
torch.addmm(out_slice, x_seq, w_seq.T, beta=0, alpha=scaling_tensor[lora_idx].item(), out=out_slice)
token_offset += seq_len
# sgemm_lora_b_fwd: 将 mm + add_ 融合为一次 addmm 调用 (beta=1 表示保留原有输出 )
out_slice = output[token_offset : token_offset + seq_len, slice_start_output:slice_end_output]
torch.addmm(out_slice, x_slice, w_slice.T, beta=1, alpha=1, out=out_slice)
评论区精华
无 review 评论,但提交历史显示了一次额外提交(style: use .item() for addmm alpha scalar),表明作者在实现后意识到将 scaling_tensor[lora_idx] 转换为 Python 标量(通过 .item())可以避免潜在的张量到张量的运算开销,体现了性能优化的细致考量。
风险与影响
- 风险:低风险。变更集中在 torch-native 后端,该后端在 LoRA 较少时使用,不影响其他后端(如 Triton csgmv)。所有张量形状和数据类型保持不变,仅将两次运算合并为一次,语义等价。需确认
.item() 在 GPU 张量上的调用是否触发同步,但 scalings_cpu 已在 CPU 上,因此 .item() 无同步开销。
- 影响:对用户:提升 torch-native 后端的推理性能(~4.4% RPS/TPS),影响范围限于使用该后端的 LoRA 部署场景。对系统:无额外运行时开销,仅增加一个 CPU 张量副本(非常小)。对团队:代码简洁度提升,易于维护。
- 风险标记:低风险,语义等价
关联脉络
- PR #23649 [diffusion] support LoRA for LTX2.3: 同为 LoRA 相关变更,涉及 diffusion 领域的 LoRA 支持,但无直接代码依赖。
参与讨论