执行摘要
- 一句话:修复 torch-native LoRA 多适配器 tensor 尺寸不匹配
- 推荐动作:值得合并,修复明确且风险低。建议开发者关注其他 LoRA 后端(如 torch_trtllm)是否存在类似
batch_size vs segment 数的假设。
功能与动机
修复当 batch 中部分连续请求共享同一 adapter 时触发的 RuntimeError:prepare_lora_batch 用 torch.unique_consecutive 去重 weight indices 后,仍使用 batch_size 而非去重后的唯一数作为 weight_indices 拷贝长度和 num_segments,导致 tensor 尺寸不匹配。PR body 中给出了完整调用栈。
实现拆解
1. 核心修复:改用去重后的 segment 数
在 python/sglang/srt/lora/backend/torch_backend.py 的 prepare_lora_batch 方法中,新增 num_segments = len(weight_indices_tensor) 变量,替代原来所有使用 forward_batch.batch_size 的地方。
weight_indices_tensor 已由 torch.unique_consecutive 处理,长度即去重后的 segment 数,因此 num_segments 总小于等于 batch_size。
2. 测试用例覆盖 multi-adapter 合并路径
在 test/manual/lora/test_torch_backend.py 中,修改 weight_indices = [0, 0, 1] 模拟 3 个请求中前两个连续共享 adapter 0,第三个使用 adapter 1,使 torch.unique_consecutive 合并为 2 个 segment,触发修复路径。
3. 同步调整测试参数
batch_size 从 2 改为 3
seq_lens 从 [1,1] 改为 [1,1,1]
input_ids 形状调整为 (3,1),seq_lens_sum 改为 3
test_run_qkv_lora 中 output_offset 从 [0,3,6,9,12](4 slices) 改为 [0,3,6,9](3 slices),匹配后端硬编码的 num_slices=3,避免形状不匹配
关键文件:
python/sglang/srt/lora/backend/torch_backend.py(模块 LoRA 后端;类别 source;类型 core-logic;符号 prepare_lora_batch): 核心修复文件:prepare_lora_batch 方法中改用 num_segments 变量替代 batch_size。
test/manual/lora/test_torch_backend.py(模块 测试;类别 test;类型 test-coverage;符号 TestTorchNativeLoRABackend): 测试覆盖:更新测试用例以覆盖 weight_indices 去重后 segment 数小于 batch_size 的场景。
关键符号:prepare_lora_batch
关键源码片段
python/sglang/srt/lora/backend/torch_backend.py
核心修复文件:prepare_lora_batch 方法中改用 num_segments 变量替代 batch_size。
# python/sglang/srt/lora/backend/torch_backend.py
# prepare_lora_batch 方法中的关键修复片段
def prepare_lora_batch(self, forward_batch, weight_indices, lora_ranks, scalings, use_cuda_graph=False):
# ... 前面代码通过 torch.unique_consecutive 得到 unique_weight_indices_tensor ...
weight_indices_tensor = unique_weight_indices_tensor.pin_memory()
bs = forward_batch.batch_size
num_segments = len(weight_indices_tensor) # 实际去重后的 segment 数,可能小于 bs
if use_cuda_graph:
batch_info = self.cuda_graph_batch_info
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = num_segments # 修复:使用实际 segment 数
else:
max_len = max(seg_lens_cpu)
batch_info = TorchNativeLoRABatchInfo(
bs=forward_batch.batch_size,
num_segments=num_segments, # 修复
max_len=max_len,
use_cuda_graph=False,
# ... 其他字段 ...
)
# 异步拷贝到设备
batch_info.weight_indices[:num_segments].copy_( # 修复:用 num_segments 而非 bs
weight_indices_tensor, non_blocking=True
)
# ... 后续拷贝 seg_indptr, seg_lens 等 ...
test/manual/lora/test_torch_backend.py
测试覆盖:更新测试用例以覆盖 weight_indices 去重后 segment 数小于 batch_size 的场景。
# test/manual/lora/test_torch_backend.py
# 测试类配置片段,用于覆盖 multi-adapter 合并场景
class TestTorchNativeLoRABackend(CustomTestCase):
device = "cpu"
# 3 个请求,前两个共享 adapter 0,第三个用 adapter 1
weight_indices = [0, 0, 1]
lora_ranks = [1, 1]
scalings = [1.0, 0.5]
seq_lens = [1, 1, 1]
use_cuda_graph = False
forward_batch = ForwardBatch(
forward_mode=ForwardMode.EXTEND,
batch_size=3, # batch_size 为 3
input_ids=torch.tensor([[1], [2], [3]], dtype=torch.int32),
extend_seq_lens_cpu=seq_lens,
# ... 其他参数 ...
)
评论区精华
无 review 评论。仅有 CI 触发指令和自动机器人摘要注释。
风险与影响
- 风险:低风险。修复逻辑明确:将
batch_size 替换为实际 segment 数 num_segments,该变量总是小于等于 batch_size,不会导致越界。weight_indices[:num_segments].copy_(weight_indices_tensor) 中 tensor 长度匹配,不会溢出。CUDA graph 路径中 batch_info.num_segments 也同步更新,保持一致性。主要风险在于若其他位置隐式依赖 num_segments == batch_size 的假设,但当前代码中无此依赖。
- 影响:直接影响 TorchNativeLoRABackend(非默认后端)在 multi-adapter 场景下的正确性。修复后多请求共享 adapter 不再触发 RuntimeError。对其他后端(Triton、FlashInfer)无影响。测试覆盖了该场景,CI 通过。
- 风险标记:核心路径变更(LoRA batch 准备), 缺少更多后端验证
关联脉络
- PR #23649 [diffusion] support LoRA for LTX2.3: 同一仓库中另一涉及 LoRA 支持的 PR,但功能不同,可作为跨模块 LoRA 改进的参考。
参与讨论