Prhub

#20562 Use torch.addmm instead of separate mm and add_ calls for LoRA torch.native

原始 PR 作者 satyamk7054 合并时间 2026-03-27 05:35 文件变更 2 提交数 3 评论 9 代码增减 +20 / -9

执行摘要

LoRA torch-native 后端融合 mm+add_ 为 addmm

当 LoRA 数量较少(4-8 个)且输入为长序列(prefill/embedding)时,torch-native 后端性能优于 csgmv。PR body 指出通过融合 mm 与缩放/累加操作可以进一步提升 torch-native 后端的性能。

值得精读。该 PR 展示了典型的 PyTorch 算子融合优化模式(将多次 CUDA 内核调用合并为一次 addmm),是性能调优的经典案例。同时,.item() 使用和 CPU 张量预分配的做法值得借鉴。

讨论亮点

无 review 评论,但提交历史显示了一次额外提交(style: use .item() for addmm alpha scalar),表明作者在实现后意识到将 scaling_tensor[lora_idx] 转换为 Python 标量(通过 .item())可以避免潜在的张量到张量的运算开销,体现了性能优化的细致考量。

实现拆解

  1. 新增 CPU 端缩放张量:在 TorchNativeLoRABatchInfo 中添加 scalings_cpu 字段(CPU 设备),并在 prepare_lora_batch 中赋值。
  2. sgemm_lora_a_fwd 融合:将 torch.mm(x_seq, w_seq.T)scaling_tensor[lora_idx] * result 合并为 torch.addmm(out_slice, x_seq, w_seq.T, beta=0, alpha=scaling_tensor[lora_idx].item(), out=out_slice)
  3. sgemm_lora_b_fwd 融合:将 torch.mm(x_slice, w_slice.T)output[...].add_(...) 合并为 torch.addmm(out_slice, x_slice, w_slice.T, beta=1, alpha=1, out=out_slice)
  4. 调用点更新:所有调用 sgemm_lora_a_fwd 的地方(lora_a、qkv_lora、gate_up_lora)将 scaling_tensor 参数改为 scalings_cpu
文件 模块 状态 重要度
python/sglang/srt/lora/backend/torch_backend.py LoRA 后端 modified 5.57
python/sglang/srt/lora/torch_ops/lora_ops.py LoRA 算子 modified 4.02

关键符号

sgemm_lora_a_fwd sgemm_lora_b_fwd prepare_lora_batch TorchNativeLoRABatchInfo

关键源码片段

python/sglang/srt/lora/backend/torch_backend.py core-logic

修改了 LoRA batch info 数据结构,新增 `scalings_cpu` 字段,并更新所有调用 `sgemm_lora_a_fwd` 的地方以传递新的 CPU 缩放张量。

# python/sglang/srt/lora/backend/torch_backend.py
# 在 TorchNativeLoRABatchInfo 中新增 CPU 侧缩放张量,避免 GPU-CPU 隐式同步
@dataclass
class TorchNativeLoRABatchInfo(LoRABatchInfo):
    lora_ranks_cpu: Optional[torch.Tensor] = None
    seg_indptr_cpu: Optional[torch.Tensor] = None
    seg_lens_cpu: Optional[torch.Tensor] = None
    weight_indices_cpu: Optional[torch.Tensor] = None
    # 新增 : 缩放因子张量,预先放置到 CPU 设备
    scalings_cpu: Optional[torch.Tensor] = None
​
    # 在 prepare_lora_batch 中赋值
    # ... ( 原有代码 ) ...
    batch_info.scalings_cpu = scalings_tensor
    self.batch_info = batch_info
python/sglang/srt/lora/torch_ops/lora_ops.py infrastructure

修改了 `sgemm_lora_a_fwd` 和 `sgemm_lora_b_fwd` 两个核心函数,将独立的 `torch.mm` + 缩放 / 累加操作替换为单个 `torch.addmm` 调用,实现算子融合。

# python/sglang/srt/lora/torch_ops/lora_ops.py
# sgemm_lora_a_fwd: 将 mm + 缩放融合为一次 addmm 调用
if rank > 0:
    x_seq = inputs[token_offset : token_offset + seq_len, :]
    w_seq = weights[lora_idx, : num_slices * rank, :]
    out_slice = output[token_offset : token_offset + seq_len, : num_slices * rank]
    # 使用 beta=0 表示不累加,alpha 为缩放因子
    torch.addmm(out_slice, x_seq, w_seq.T, beta=0, alpha=scaling_tensor[lora_idx].item(), out=out_slice)
    token_offset += seq_len# sgemm_lora_b_fwd: 将 mm + add_ 融合为一次 addmm 调用 (beta=1 表示保留原有输出 )
out_slice = output[token_offset : token_offset + seq_len, slice_start_output:slice_end_output]
torch.addmm(out_slice, x_slice, w_slice.T, beta=1, alpha=1, out=out_slice)

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

低风险。变更集中在 torch-native 后端,该后端在 LoRA 较少时使用,不影响其他后端(如 Triton csgmv)。所有张量形状和数据类型保持不变,仅将两次运算合并为一次,语义等价。需确认 .item() 在 GPU 张量上的调用是否触发同步,但 scalings_cpu 已在 CPU 上,因此 .item() 无同步开销。

对用户:提升 torch-native 后端的推理性能(~4.4% RPS/TPS),影响范围限于使用该后端的 LoRA 部署场景。对系统:无额外运行时开销,仅增加一个 CPU 张量副本(非常小)。对团队:代码简洁度提升,易于维护。

低风险,语义等价

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论