Prhub

#42554 [PD][Nixl] Mamba prefix caching mode support

原始 PR 作者 NickLucche 合并时间 2026-06-04 21:41 文件变更 3 提交数 6 评论 4 代码增减 +97 / -6

执行摘要

PD Nixl 连接器支持 Mamba 前缀缓存模式

在 PD(Prefill-Decode)Mamba 配置中启用前缀缓存时,原代码会在 _apply_prefix_caching 中触发断言失败(SSM 组要求本地与远程块数相等,但前缀缓存会插入占位块导致不等)。此 PR 通过识别并裁剪 SSM 占位块来消除断言,实现对前缀缓存命中的正确处理。

值得精读,了解分布式前缀缓存在 Mamba 模型上的实现模式。但建议关注边缘情况的处理,考虑后续修复断言和切片逻辑。

讨论亮点

gemini-code-assist[bot] 在 review 中指出两个高风险问题:

1) 断言 num_local_blocks == 1 在完全缓存命中(本地块数为 0)时会失败;
2) 切片 remote_group[-num_local_blocks:]num_local_blocks=0 时返回全部列表而非空列表,导致块计数不匹配。这两个建议均未被采纳,PR 已合并,风险遗留。

实现拆解

  1. Nixl 连接器逻辑增强:在 vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py 中修改 _apply_prefix_caching 方法。当检测到 SSM 组且远程块数多于本地时,对远程进行尾部裁剪以移除占位块;当 FA 组且 physical_blocks_per_logical_kv_block 相同时,对远程进行尾部裁剪以处理局部前缀缓存命中;否则保留原有的块数匹配逻辑。
  2. 调度器适配:在 vllm/v1/core/sched/scheduler.py 中修改两处:移除 _mamba_block_aligned_split 中对 num_external_computed_tokens 必须为零的断言,允许外部 KV 连接器传递已计算的 token 数;在 schedule 方法中调用 _mamba_block_aligned_split 时增加 not load_kv_async 条件,避免在异步加载 KV 时进行块对齐分割。
  3. 测试覆盖:在 tests/v1/kv_connector/unit/test_nixl_connector_hma.py 中新增 test_apply_prefix_caching_ssm_prefix_cache_hit,通过参数化测试验证 SSM 仅裁剪、FA 部分命中、以及两者混合三种场景。
文件 模块 状态 重要度
vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py KV 连接器 modified 6.55
vllm/v1/core/sched/scheduler.py 调度器 modified 5.8
tests/v1/kv_connector/unit/test_nixl_connector_hma.py 连接器测试 modified 6.44

关键符号

_apply_prefix_caching _mamba_block_aligned_split test_apply_prefix_caching_ssm_prefix_cache_hit

关键源码片段

vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py core-logic

核心修改:为 Mamba 混合模型添加前缀缓存处理逻辑,修改 _apply_prefix_caching 方法以支持 SSM 占位块裁剪和 FA 部分前缀缓存命中。

def _apply_prefix_caching(self, local_block_ids, remote_block_ids, remote_physical_per_logical):
    # 部分前缀缓存命中 : 只读取未计算的块。
    # 跳过 mamba 组——它们的块代表完整状态(conv+ssm),
    # 而不是按 token 的数据,所以裁剪会破坏传输。
    remote_block_ids = list(remote_block_ids)
    if not self._has_mamba:
        for i, remote_group in enumerate(remote_block_ids):
            num_local_blocks = len(local_block_ids[i])
            assert num_local_blocks <= len(remote_group)
            if num_local_blocks < len(remote_group):
                remote_block_ids[i] = remote_group[-num_local_blocks:]
    else:
        local_block_ids = list(local_block_ids)
        for i, remote_group in enumerate(remote_block_ids):
            num_local_blocks = len(local_block_ids[i])
            num_remote_blocks = len(remote_group)
​
            # SSM 前缀缓存处理:远程块中最后一块是实际状态,前面是占位块
            if (
                _is_ssm_spec(self._group_spec_types[i])
                and num_local_blocks < num_remote_blocks
            ):
                assert num_local_blocks == 1, "SSM 只能有一个本地块"
                remote_block_ids[i] = remote_group[-num_local_blocks:] # 保留最后一个实际块
​
            # FA 部分前缀缓存命中(仅当 block_size 匹配时)
            elif (
                self._physical_blocks_per_logical_kv_block
                == remote_physical_per_logical
                and num_local_blocks < num_remote_blocks
            ):
                remote_block_ids[i] = remote_group[-num_local_blocks:]
​
            # 其他情况:保持原有对齐逻辑(异或异构 block_size)
            else:
                max_padding = max(
                    self._physical_blocks_per_logical_kv_block,
                    remote_physical_per_logical,
                )
                assert abs(num_local_blocks - num_remote_blocks) < max_padding, (
                    f"Group {i}: |{num_local_blocks} - "
                    f"{num_remote_blocks}| >= {max_padding}"
                )
                num_blocks = min(num_local_blocks, num_remote_blocks)
                local_block_ids[i] = local_block_ids[i][:num_blocks]
                remote_block_ids[i] = remote_group[:num_blocks]
​
    return local_block_ids, remote_block_ids
tests/v1/kv_connector/unit/test_nixl_connector_hma.py test-coverage

新增测试用例 test_apply_prefix_caching_ssm_prefix_cache_hit,覆盖三种前缀缓存场景(SSM 单独、FA 单独、混合)。

@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "local_physical_per_logical,remote_physical_per_logical,"
    "local_block_ids,remote_block_ids,"
    "expected_local,expected_remote",
    [
        # SSM 前缀缓存:远程有 3 个占位块 + 1 个实际块
        # 本地只有 1 个实际块
        pytest.param(
            10, 10,
            [list(range(10)), [42]],
            [list(range(10)), [40, 41, 42, 43]],
            [list(range(10)), [42]],
            [list(range(10)), [43]],
            id="ssm_prefix_trim_only",
        ),
        # FA 部分前缀缓存命中(同构 TP):本地 4 块已缓存,远程 10 块
        pytest.param(
            10, 10,
            [list(range(6, 10)), [42]],
            [list(range(10)), [42]],
            [list(range(6, 10)), [42]],
            [list(range(6, 10)), [42]],
            id="fa_prefix_hit_homo_tp",
        ),
        # 混合场景:FA 部分命中 + SSM 占位块裁剪
        pytest.param(
            10, 10,
            [[6, 7, 8, 9], [99]],
            [list(range(10)), [10, 20, 99]],
            [[6, 7, 8, 9], [99]],
            [[6, 7, 8, 9], [99]],
            id="fa_prefix_hit_and_ssm_trim",
        ),
    ],
)
def test_apply_prefix_caching_ssm_prefix_cache_hit(
    local_physical_per_logical, remote_physical_per_logical,
    local_block_ids, remote_block_ids,
    expected_local, expected_remote,
):
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
        NixlConnectorWorker,
    )
    from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
​
    worker = object.__new__(NixlConnectorWorker)
    worker._has_mamba = True
    worker._physical_blocks_per_logical_kv_block = local_physical_per_logical
    worker._group_spec_types = (FullAttentionSpec, MambaSpec)
    worker.kv_cache_config = make_kv_cache_config(block_size=16, mamba_enabled=True)
​
    aligned_local, aligned_remote = worker._apply_prefix_caching(
        local_block_ids, remote_block_ids, remote_physical_per_logical
    )
​
    assert aligned_local == expected_local, (
        f"Expected local {expected_local}, got {aligned_local}"
    )
    assert aligned_remote == expected_remote, (
        f"Expected remote {expected_remote}, got {aligned_remote}"
    )

评论区精华

本地块数为零时的断言和切片正确性 正确性

gemini-code-assist[bot] 指出当完全前缀缓存命中时,SSM 本地块数可能为 0,断言 num_local_blocks==1 会失败;且切片 remote_group[-0:] 返回全部列表而非空列表,导致块计数不匹配。

结论:建议未被采纳,PR 已合并,风险遗留。 · 待处理

风险与影响

1) 断言 num_local_blocks == 1 过于严格:当请求完全命中前缀缓存时,SSM 本地块数可能为 0,触发 AssertionError。
2) 切片 remote_group[-num_local_blocks:]num_local_blocks=0 时行为不符合预期(返回整个列表),可能导致后续断言失败。
3) 当前测试用例未覆盖 num_local_blocks=0 的场景。

正面影响:允许 Mamba 模型在 PD 配置下使用前缀缓存,提升缓存命中率和推理效率。负面影响:上述风险可能导致特定边缘场景(如完全缓存命中)下运行时错误,影响稳定性。影响范围限于使用 Mamba 混合模型且启用前缀缓存的 PD 设置。

零边界情况未处理 断言可能过于严格 切片在 num_local_blocks=0 时行为异常

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论