Prhub

#20564 fix: torch-native LoRA for multi-adapter case

原始 PR 作者 satyamk7054 合并时间 2026-03-27 05:34 文件变更 2 提交数 6 评论 6 代码增减 +19 / -14

执行摘要

修复 torch-native LoRA 多适配器 tensor 尺寸不匹配

修复当 batch 中部分连续请求共享同一 adapter 时触发的 RuntimeError:prepare_lora_batchtorch.unique_consecutive 去重 weight indices 后,仍使用 batch_size 而非去重后的唯一数作为 weight_indices 拷贝长度和 num_segments,导致 tensor 尺寸不匹配。PR body 中给出了完整调用栈。

值得合并,修复明确且风险低。建议开发者关注其他 LoRA 后端(如 torch_trtllm)是否存在类似 batch_size vs segment 数的假设。

讨论亮点

无 review 评论。仅有 CI 触发指令和自动机器人摘要注释。

实现拆解

1. 核心修复:改用去重后的 segment 数

python/sglang/srt/lora/backend/torch_backend.pyprepare_lora_batch 方法中,新增 num_segments = len(weight_indices_tensor) 变量,替代原来所有使用 forward_batch.batch_size 的地方。

weight_indices_tensor 已由 torch.unique_consecutive 处理,长度即去重后的 segment 数,因此 num_segments 总小于等于 batch_size

2. 测试用例覆盖 multi-adapter 合并路径

test/manual/lora/test_torch_backend.py 中,修改 weight_indices = [0, 0, 1] 模拟 3 个请求中前两个连续共享 adapter 0,第三个使用 adapter 1,使 torch.unique_consecutive 合并为 2 个 segment,触发修复路径。

3. 同步调整测试参数

  • batch_size 从 2 改为 3
  • seq_lens[1,1] 改为 [1,1,1]
  • input_ids 形状调整为 (3,1)seq_lens_sum 改为 3
  • test_run_qkv_loraoutput_offset[0,3,6,9,12](4 slices) 改为 [0,3,6,9](3 slices),匹配后端硬编码的 num_slices=3,避免形状不匹配
文件 模块 状态 重要度
python/sglang/srt/lora/backend/torch_backend.py LoRA 后端 modified 5.31
test/manual/lora/test_torch_backend.py 测试 modified 4.56

关键符号

prepare_lora_batch

关键源码片段

python/sglang/srt/lora/backend/torch_backend.py core-logic

核心修复文件:`prepare_lora_batch` 方法中改用 `num_segments` 变量替代 `batch_size`。

# python/sglang/srt/lora/backend/torch_backend.py
# prepare_lora_batch 方法中的关键修复片段def prepare_lora_batch(self, forward_batch, weight_indices, lora_ranks, scalings, use_cuda_graph=False):
    # ... 前面代码通过 torch.unique_consecutive 得到 unique_weight_indices_tensor ...
    weight_indices_tensor = unique_weight_indices_tensor.pin_memory()
​
    bs = forward_batch.batch_size
    num_segments = len(weight_indices_tensor) # 实际去重后的 segment 数,可能小于 bs
​
    if use_cuda_graph:
        batch_info = self.cuda_graph_batch_info
        batch_info.bs = forward_batch.batch_size
        batch_info.num_segments = num_segments # 修复:使用实际 segment 数
    else:
        max_len = max(seg_lens_cpu)
        batch_info = TorchNativeLoRABatchInfo(
            bs=forward_batch.batch_size,
            num_segments=num_segments, # 修复
            max_len=max_len,
            use_cuda_graph=False,
            # ... 其他字段 ...
        )
​
    # 异步拷贝到设备
    batch_info.weight_indices[:num_segments].copy_( # 修复:用 num_segments 而非 bs
        weight_indices_tensor, non_blocking=True
    )
    # ... 后续拷贝 seg_indptr, seg_lens 等 ...
test/manual/lora/test_torch_backend.py test-coverage

测试覆盖:更新测试用例以覆盖 `weight_indices` 去重后 segment 数小于 `batch_size` 的场景。

# test/manual/lora/test_torch_backend.py
# 测试类配置片段,用于覆盖 multi-adapter 合并场景class TestTorchNativeLoRABackend(CustomTestCase):
    device = "cpu"
    # 3 个请求,前两个共享 adapter 0,第三个用 adapter 1
    weight_indices = [0, 0, 1]
    lora_ranks = [1, 1]
    scalings = [1.0, 0.5]
    seq_lens = [1, 1, 1]
    use_cuda_graph = False
​
    forward_batch = ForwardBatch(
        forward_mode=ForwardMode.EXTEND,
        batch_size=3, # batch_size 为 3
        input_ids=torch.tensor([[1], [2], [3]], dtype=torch.int32),
        extend_seq_lens_cpu=seq_lens,
        # ... 其他参数 ...
    )

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

低风险。修复逻辑明确:将 batch_size 替换为实际 segment 数 num_segments,该变量总是小于等于 batch_size,不会导致越界。weight_indices[:num_segments].copy_(weight_indices_tensor) 中 tensor 长度匹配,不会溢出。CUDA graph 路径中 batch_info.num_segments 也同步更新,保持一致性。主要风险在于若其他位置隐式依赖 num_segments == batch_size 的假设,但当前代码中无此依赖。

直接影响 TorchNativeLoRABackend(非默认后端)在 multi-adapter 场景下的正确性。修复后多请求共享 adapter 不再触发 RuntimeError。对其他后端(Triton、FlashInfer)无影响。测试覆盖了该场景,CI 通过。

核心路径变更(LoRA batch 准备) 缺少更多后端验证

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论