Prhub

#23738 fix(lora): avoid CUDA graph-breaking scalar assignment in seg_indptr

原始 PR 作者 yushengsu-thu 合并时间 2026-04-30 16:11 文件变更 1 提交数 1 评论 0 代码增减 +1 / -1

执行摘要

修复 LoRA CUDA graph 中 seg_indptr 标量赋值导致的同步点

在 CUDA graph 录制过程中,seg_indptr[0] = 0 这种标量赋值操作会触发 host-to-device 拷贝,导致同步点(sync point)产生,使得 CUDA graph 无法正常捕获。PR 描述中明确指出该赋值会“break CUDA graph capture”,并附带了性能对比截图说明改进效果。

值得快速合并,但建议在后续 PR 中补充 CUDA graph 兼容性测试,避免类似回归。

讨论亮点

无实质讨论。审核人 jybsuper 和 hnyls2002 均批准,自动化工具给出了正向确认。

实现拆解

  1. 定位问题代码:在 compute_sgemm_routing 方法中,当 use_cuda_graph 为 True 时,第 204 行使用 sgemm.seg_indptr[0] = 0seg_indptr 的第一个元素赋零,该标量操作会引入 CPU-GPU 同步。
  2. 替换为 on-device 操作:将赋值改为 sgemm.seg_indptr[0:1].zero_(),通过切片索引配合 .zero_() 原地置零,所有操作在 GPU 上完成,不触发同步。
  3. 保持其余逻辑不变:后续的 torch.cumsum 仍在设备上执行,整个 use_cuda_graph 分支现在完全符合 CUDA graph 录制要求。
文件 模块 状态 重要度
python/sglang/srt/lora/backend/triton_backend.py LoRA 后端 modified 4.72

关键符号

compute_sgemm_routing

关键源码片段

python/sglang/srt/lora/backend/triton_backend.py core-logic

单行变更加在 `compute_sgemm_routing` 方法中,修复 CUDA graph 路径的同步点问题,是本次变更的唯一文件。

def compute_sgemm_routing(self, use_cuda_graph: bool):
    """Sort tokens by adapter and build merged segments for sgemm LoRA."""
    bi = self.batch_info
    bs = bi.bs
    mlpb = self.max_loras_per_batch
    wi = bi.weight_indices[:bs]
​
    perm = torch.argsort(wi, stable=True).to(torch.int32)
    sorted_wi = wi[perm]
    adapter_ids = torch.arange(mlpb, device=wi.device, dtype=torch.int32)
    seg_starts = torch.searchsorted(sorted_wi, adapter_ids)
    seg_ends = torch.searchsorted(sorted_wi, adapter_ids, right=True)
    seg_lens = seg_ends - seg_starts
​
    if use_cuda_graph:
        sgemm = getattr(self, "cuda_graph_sgemm_batch_info", None)
        if sgemm is None:
            return
        sgemm.permutation[:bs] = perm
        sgemm.seg_lens[:] = seg_lens
        # 原写法 sgemm.seg_indptr[0] = 0 会触发 CPU-GPU 同步,破坏 CUDA graph 录制
        sgemm.seg_indptr[0:1].zero_() # 切片 .zero_() 完全在设备上执行
        torch.cumsum(sgemm.seg_lens, dim=0, out=sgemm.seg_indptr[1:])
        sgemm.max_len = bs
        sgemm.lora_ranks[:mlpb] = bi.lora_ranks[:mlpb]
        sgemm.scalings[:mlpb] = bi.scalings[:mlpb]
    else:
        seg_indptr = torch.zeros(mlpb + 1, dtype=torch.int32, device=wi.device)
        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
        sgemm = LoRABatchInfo(
            bs=mlpb,
            use_cuda_graph=False,
            num_segments=mlpb,
            seg_lens=seg_lens,
            seg_indptr=seg_indptr,
            max_len=bs,
            weight_indices=adapter_ids,
            lora_ranks=bi.lora_ranks[:mlpb].clone(),
            scalings=bi.scalings[:mlpb].clone(),
            permutation=perm,
        )
​
    self.sgemm_batch_info = sgemm

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

该改动仅影响 CUDA graph 路径中的一行代码,语义等价(对 seg_indptr[0] 置零),不会改变计算结果。风险极低,但缺少对应的测试验证(如确保 CUDA graph 录制后输出一致),建议后续添加。

影响范围仅限启用 CUDA graph 的 LoRA 推理场景,修复后可正常录制 graph,消除因同步点导致的性能退化。对未使用 CUDA graph 的路径无影响。

缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论