Prhub

#24160 [lora] Share MoE LoRA Info

原始 PR 作者 erikwijmans 合并时间 2026-05-29 10:01 文件变更 9 提交数 1 评论 1 代码增减 +310 / -46

执行摘要

共享 MoE LoRA batch 信息减少重复计算

当前 MoE LoRA 设计在每一层都重复计算仅依赖 batch 的信息,造成了不必要的开销。PR body 指出 'a significant amount of information is computed every layer despite information depending solely on just the batch info.'

建议对 weight_indices 可能为 -1 的情况进行防御性处理(如掩码后 scatter),并增加相应测试。在非 CUDA 平台上验证 kernel 兼容性。整体设计良好,值得精读。

讨论亮点

只有一个来自 gemini-code-assist[bot] 的 review 评论,指出 _compute_moe_lora_info 实现中如果 weight_indices 包含 -1(表示无 LoRA 的 token),则 scatter_ 会因负索引崩溃,lora_ranks[weight_indices.long()] 也会错误索引。该问题被标记为 high priority,但最终合并的代码中未显式处理 -1,可能需要调用方保证 weight_indices 不含 -1 或后续修复。

实现拆解

  1. 新增 MoELoRABatchInfo 数据结构python/sglang/srt/lora/utils.py):定义集中存放 MoE LoRA 所需 per-batch 元数据的 dataclass,包括 seg_indptr、req_to_lora、adapter_enabled、token_lora_mapping。
  2. 集中计算 MoE LoRA 信息python/sglang/srt/lora/backend/base_backend.py):新增 _add_moe_lora_info 方法,由各后端在 prepare_lora_batch 末尾调用;该方法内部调用 _compute_moe_lora_info 函数(通过 @torch.compile 或 Triton kernel)一次性生成所有元数据,并写入 batch_info.moe_lora_info
  3. Triton kernel 实现base_backend.py):新增 _compute_moe_lora_info_kernel,高效计算 token_lora_mapping(通过 repeat_interleave 方式)和 adapter_enabled(基于 lora_ranks 权重),并断言 kernel 覆盖全部 token。
  4. 消费侧简化python/sglang/srt/lora/layers.pylora_moe_runners.py):FusedMoEWithLoRA._get_lora_info 直接使用 batch_info.moe_lora_info 填充 LoRAInfo 结构,移除原有的逐层 adapter_enabled 计算和 _compute_token_lora_mapping 函数。
  5. 后端集成:在 ascend_backend、chunked_backend、torch_backend、triton_backend 的 prepare_lora_batch 末尾添加 _add_moe_lora_info 调用,并在 FusedMoEWithLoRA 构造时设置 lora_backend.is_moe_lora = True 以启用该路径。
  6. 单元测试test/registered/lora/test_moe_lora_info.py):测试 _compute_moe_lora_info 在不同 segment 长度、预分配缓冲区 vs 新分配、以及 kernel 覆盖不足错误抛出等场景下的正确性。
文件 模块 状态 重要度
python/sglang/srt/lora/backend/base_backend.py 后端核心 modified 8.74
python/sglang/srt/lora/utils.py 数据定义 modified 6.8
python/sglang/srt/lora/layers.py 层包装 modified 6.65
python/sglang/srt/lora/lora_moe_runners.py MoE 运行 modified 6.35
test/registered/lora/test_moe_lora_info.py 测试 added 7.18

关键符号

is_moe_lora _add_moe_lora_info _compute_moe_lora_info_kernel _compute_moe_lora_info MoELoRABatchInfo _expected_adapter_enabled test_compute_moe_lora_info_expands_segments test_compute_moe_lora_info_rejects_undercovered_launch _compute_token_lora_mapping _get_lora_info

关键源码片段

python/sglang/srt/lora/utils.py core-logic

新增 MoELoRABatchInfo 数据类,并在 LoRABatchInfo 中增加 moe_lora_info 字段,是共享信息的容器。

@dataclass
class MoELoRABatchInfo:
    # Per-request segment indptrs used by MoE LoRA routing, shape (bs + 1,).
    seg_indptr: torch.Tensor
​
    # Per-request adapter index used by MoE LoRA routing, shape (bs,).
    req_to_lora: torch.Tensor
​
    # A mask indicating if lora adapter is enabled. Shape (num_loras,)
    adapter_enabled: torch.Tensor
​
    # A mapping of which lora adapter is used for each token. Shape (num_tokens,)
    # If a token has no lora adapter, the value is -1.
    token_lora_mapping: torch.Tensor
​
​
@dataclass
class LoRABatchInfo:
    # ... existing fields ...
    # MoE LoRA batch info
    moe_lora_info: Optional[MoELoRABatchInfo] = None
test/registered/lora/test_moe_lora_info.py test-coverage

新增测试文件,覆盖新 kernel 在不同 segment 长度、预分配缓冲区、不足覆盖等场景下的正确性。

import sysimport pytest
import torchfrom sglang.srt.lora.backend.base_backend import _compute_moe_lora_info
from sglang.test.ci.ci_register import register_cuda_ciregister_cuda_ci(est_time=5, stage="base-b", runner_config="1-gpu-small")
​
​
def _expected_adapter_enabled(
    lora_ranks: torch.Tensor,
    weight_indices: torch.Tensor,
) -> torch.Tensor:
    # 使用 scatter 构建预期的 adapter_enabled 掩码
    expected = torch.zeros_like(lora_ranks)
    expected.scatter_(
        0,
        weight_indices.long(),
        (lora_ranks[weight_indices.long()] > 0).to(torch.int32),
    )
    return expected
​
​
@pytest.mark.parametrize("use_preallocated_buffers", [False, True])
def test_compute_moe_lora_info_expands_segments(use_preallocated_buffers: bool):
    # 测试 kernel 在不同 segment 长度下能否正确展开 token_lora_mapping 和 adapter_enabled
    device = "cuda"
    seg_lens = torch.tensor([5, 1, 7, 3, 9, 2], dtype=torch.int32, device=device)
    seg_indptr = torch.zeros((seg_lens.numel() + 1,), dtype=torch.int32, device=device)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
​
    weight_indices = torch.tensor([2, 0, 5, 2, 3, 7], dtype=torch.int32, device=device)
    lora_ranks = torch.tensor(
        [0, 12, 16, 32, 24, 8, 0, 4], dtype=torch.int32, device=device
    )
    num_tokens = int(seg_indptr[-1].item())
​
    if use_preallocated_buffers:
        adapter_enabled = torch.full_like(lora_ranks, 123)
        token_lora_mapping = torch.full(
            (num_tokens + 11,), 456, dtype=torch.int32, device=device
        )
    else:
        adapter_enabled = None
        token_lora_mapping = None
​
    actual_enabled, actual_mapping = _compute_moe_lora_info(
        num_tokens,
        seg_indptr,
        lora_ranks,
        weight_indices,
        adapter_enabled,
        token_lora_mapping,
        max_len=int(seg_lens.max().item()),
    )
    torch.cuda.synchronize()
​
    expected_mapping = torch.repeat_interleave(weight_indices, seg_lens)
    expected_enabled = _expected_adapter_enabled(lora_ranks, weight_indices)
​
    torch.testing.assert_close(actual_mapping, expected_mapping)
    torch.testing.assert_close(actual_enabled, expected_enabled)
​
    if use_preallocated_buffers:
        # 验证使用了预分配缓冲区而非新分配
        assert actual_mapping.data_ptr() == token_lora_mapping.data_ptr()
​
​
def test_compute_moe_lora_info_rejects_undercovered_launch():
    # 测试 kernel 无法覆盖所有 token 时是否抛出 AssertionError
    device = "cuda"
    seg_indptr = torch.tensor([0, 300], dtype=torch.int32, device=device)
    weight_indices = torch.tensor([0], dtype=torch.int32, device=device)
    lora_ranks = torch.tensor([16], dtype=torch.int32, device=device)
​
    with pytest.raises(AssertionError, match="under-covers tokens"):
        _compute_moe_lora_info(
            300,
            seg_indptr,
            lora_ranks,
            weight_indices,
            None,
            None,
            max_len=1,
        )
​
​
if __name__ == "__main__":
    sys.exit(pytest.main([__file__]))

评论区精华

weight_indices 为 -1 时的处理 正确性

gemini-code-assist[bot] 指出在 _compute_moe_lora_info 中如果 weight_indices 包含 -1(表示无 LoRA 的 token),scatter_ 会因负索引崩溃,lora_ranks[weight_indices.long()] 也会错误访问。建议在 scatter 前屏蔽 -1。

结论:未在最终代码中显式修复,但可能依赖调用方保证 weight_indices 不含 -1。 · unresolved

风险与影响

  1. 索引越界风险:如果 weight_indices 中出现 -1_compute_moe_lora_info_kernel 中的 scatter_lora_ranks 索引会导致崩溃或错误结果。当前测试和主要调用路径中 weight_indices 均为非负值,但缺少对副作用的保护。
  2. 新 kernel 平台兼容性:Triton kernel _compute_moe_lora_info_kernel 依赖于 CUDA,在非 CUDA 平台(如 NPU)上可能无法运行。但 ascend_backend 使用了不同的实现路径(调用 _add_moe_lora_info 但实际执行的是 _compute_moe_lora_info 的纯 PyTorch 版本?从代码看 ascend_backend 也调用了 _add_moe_lora_info,而后者当前 CUDA-only,仍需验证)。
  3. 断言失败风险:kernel 内部断言要求覆盖所有 token,若 max_len 过小或 segment 长度异常可能导致推理中断。
  4. 内存开销增加:预计算的 token_lora_mappingadapter_enabled 需要额外显存,与 batch size 成正比。

影响范围仅限于使用 MoE LoRA 的模型(如 Kimi K2.5、DeepSeek 等)。对于非 MoE LoRA 的场景无影响。正面性能收益显著:单层 MoE 减少 20% 耗时,端到端 decode 提升约 12%。对开发人员的影响:需要理解 MoELoRABatchInfo 数据结构,并确保在添加新 LoRA 后端时调用 _add_moe_lora_info。测试覆盖良好。

索引越界风险 (weight_indices -1) 新 kernel 平台兼容性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论