Prhub

#25277 [UnifiedTree]: Fix UnifiedRadixCache device match semantics with HiCache

原始 PR 作者 hzh0425 合并时间 2026-05-16 00:40 文件变更 7 提交数 5 评论 7 代码增减 +398 / -110

执行摘要

修复 HiCache 下 UnifiedRadixCache 的设备匹配语义

此 PR 阐明 UnifiedRadixCache 在 HiCache 启用时的匹配语义:分离 best_match_node(允许 host-backed 节点)和用作 last_device_node 的设备驻留匹配锚点;追踪 last_device_node 作为所有活跃组件均为设备驻留的最深节点;在 HiCache 下临时禁用 Mamba branching seqlen;更新 init_load_back 测试以验证前缀索引收集、仅 aux load-back 和回退行为。

建议精读,尤其是设计模式:如何在一个匹配过程中同时跟踪 best host match 和 best device match。create_match_validator 的参数化设计值得借鉴。如果团队正在开发缓存层或类似的分层匹配系统,此 PR 提供了清晰的参考实现。

讨论亮点

Review 中 gemini-code-assist[bot] 提出了三个中等优先级的优化建议:

  • 使用 generator 表达式替代列表解析以利用短路(_all_valid)。
  • 修正 _collect_new_prefix_indices 的返回类型提示为 torch.Tensor 而非 Optional[torch.Tensor]
  • 移除 init_load_backif new_indices is None 的死代码。
    合并者 ispobock 直接批准了 PR,作者未针对这些评论修改代码。这些建议属于可选的性能/清理改进,不影响正确性。

实现拆解

  1. _match_prefix_helper 中引入 best_match_device_nodeseparate_device_match 标志,当 HiCache 激活时(cache_controller is not None),同时维护两套 validators(全量 validators 和仅设备 validators);非 HiCache 模式行为不变。
  2. 为所有 create_match_validator 加入 match_device_only 参数,各组件(Full、Mamba、SWA、TreeComponent)各自实现过滤逻辑:仅设备 validators 忽略 host_value。
  3. 删除未使用的 _match_prefix_helper_readonly 方法,其逻辑已合入新版 _match_prefix_helper;更新调用者(test_unified_radix_cache_unittest.py)中的匹配断言。
  4. MambaComponent.finalize_match_result 中当 HiCache 激活时禁用 branching seqlen 填充(self.cache.cache_controller is None 条件),避免在 host-backed 状态下触发未测试路径。
  5. init_load_back 中新增 _collect_new_prefix_indices 辅助函数,从 best_match_node 向上收集前缀索引直到 last_best_match_device_node;简化 success 路径(移除 None 检查死代码)。
  6. 测试方面:新增 test_match_prefix_best_and_device_node_without_hicachetest_hicache_mamba_host_best_match_keeps_device_anchortest_hicache_swa_host_best_match_keeps_device_anchortest_mamba_branching_seqlen_disabled_under_hicache;修改 _load_back_node 返回逻辑使测试可断言加载后的 value 存在。
文件 模块 状态 重要度
python/sglang/srt/mem_cache/unified_radix_cache.py 缓存层 modified 8.74
test/registered/unit/mem_cache/test_unified_radix_cache_unittest.py 单元测试 modified 7.81
python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py 缓存层 modified 6.2
python/sglang/srt/mem_cache/unified_cache_components/full_component.py 缓存层 modified 6.17
python/sglang/srt/mem_cache/unified_cache_components/tree_component.py 缓存层 modified 6.07
python/sglang/srt/mem_cache/unified_cache_components/swa_component.py 缓存层 modified 5.98
test/registered/radix_cache/test_unified_radix_hicache_kl.py 集成测试 modified 2.85

关键符号

_match_prefix_helper _match_prefix_helper_readonly _all_valid _collect_new_prefix_indices create_match_validator finalize_match_result init_load_back

关键源码片段

python/sglang/srt/mem_cache/unified_radix_cache.py core-logic

核心匹配逻辑重构:分离 best_match_node 与 best_match_device_node,删除未使用的 readonly helper,新增 _all_valid 生成器检查,重写 init_load_back 的前缀索引收集

def _match_prefix_helper(
    self, key: RadixKey
) -> tuple[list[torch.Tensor], UnifiedTreeNode, UnifiedTreeNode, int]:
    # 新设计:在 HiCache 模式下分为全量 validators(允许 host-backed 节点)
    # 和设备 validators(仅允许 device 驻留节点),用于分别追踪
    # best_match_node(允许 host)和 best_match_device_node(仅 device)。
    node = self.root_node
    child_key = key.child_key(self.page_size)
    value: list[torch.Tensor] = []
    best_match_node = node
    best_match_device_node = node
    best_match_device_value_len = 0
    separate_device_match = self.cache_controller is not None
​
    if separate_device_match:
        validators = tuple(
            comp.create_match_validator() for comp in self._components_tuple
        )
        device_validators = tuple(
            comp.create_match_validator(match_device_only=True)
            for comp in self._components_tuple
        )
    else:
        # 非 HiCache 模式:全量 validators == 设备 validators
        validators = tuple(
            comp.create_match_validator(match_device_only=True)
            for comp in self._components_tuple
        )
        device_validators = validators
​
    def _all_valid(validators, node):
        # 使用生成器表达式允许短路,避免列表分配
        return all(v(node) for v in validators)
​
    while len(key) > 0 and child_key in node.children:
        child = node.children[child_key]
        # 死节点(evicted 且无 host backup)停止遍历
        if child.evicted and not child.backuped:
            break
        prefix_len = child.key.match(key, page_size=self.page_size)
        if prefix_len < len(child.key):
            # 部分匹配:不 split,停止
            break
        if not child.evicted:
            value.append(child.component_data[BASE_COMPONENT_TYPE].value)
        node = child
        if _all_valid(validators, node):
            best_match_node = node
            best_value_len = len(value) # 对应旧的 best_value_len
        if separate_device_match and _all_valid(device_validators, node):
            best_match_device_node = node
            best_match_device_value_len = len(value)
        # 注意:非 HiCache 下同时更新 best_match_device_node 由调用者负责
        key = key[prefix_len:]
        if len(key):
            child_key = key.child_key(self.page_size)
    return value, best_match_node, best_match_device_node, best_match_device_value_len
python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py core-logic

create_match_validator 加入 match_device_only 参数;finalize_match_result 中禁用 HiCache 下的 branching seqlen

def create_match_validator(
    self, match_device_only: bool = False
) -> Callable[[UnifiedTreeNode], bool]:
    ct = self.component_type
    if match_device_only:
        # 仅设备模式:只检查 device 上的 value 是否存在
        return lambda node: node.component_data[ct].value is not None
​
    # HiCache 模式:evicted + backuped(host_value 存在)也是有效匹配
    return lambda node: (
        node.component_data[ct].value is not None
        or node.component_data[ct].host_value is not None
    )def finalize_match_result(
    self, result, params, value_chunks, best_value_len
) -> MatchResult:
    # HiCache 可以继续使用前缀匹配并加载 host-backed 的 Mamba 状态。
    # 暂时跳过 branching-state 填充,后续可添加 HiCache 感知的分支策略。
    if self.cache.cache_controller is None and len(value_chunks) > best_value_len:
        # 非 HiCache:计算 branching_seqlen
        chunk_size = get_global_server_args().mamba_cache_chunk_size
        aligned_seqlen = (
            sum(len(v) for v in value_chunks) // chunk_size
        ) * chunk_size
        branching_seqlen = aligned_seqlen if aligned_seqlen > 0 else None
    else:
        branching_seqlen = None
    # ... 继续处理 cow_mamba

评论区精华

_all_valid 使用 generator 表达式的建议 性能

gemini-code-assist[bot] 建议将 `_all_valid` 中的列表解析改为生成器表达式,以利用短路并减少内存分配。

结论:作者未修改代码,合并者直接批准,该建议未被采纳。 · 待处理

_collect_new_prefix_indices 返回类型提示修正 正确性

gemini-code-assist[bot] 指出该函数始终返回 `torch.Tensor`,应使用 `torch.Tensor` 而非 `Optional[torch.Tensor]`。

结论:作者未修改函数签名,合并者批准,类型提示可能保持原样。 · 待处理

移除 init_load_back 中 dead code 的建议 设计

gemini-code-assist[bot] 指出 `if new_indices is None` 是死代码,应移除简化逻辑。

结论:作者未删除该分支,但合并者接受当前实现。 · 待处理

风险与影响

1) 核心路径变更_match_prefix_helper 是每次请求前缀匹配的必经函数,重构后语义分离可能引入回归,但新增的单元测试覆盖了主要场景。
2) Mamba branching seqlen 禁用:当 HiCache 激活且匹配到 host-backed 节点时,Mamba 的 branching 状态填充被跳过,可能影响 Mamba 模型的缓存跳跃性能,但该路径在 HiCache 下原本就不稳定,禁用更为安全。
3) 组件 validator 的一致性:所有四个组件(Full、Mamba、SWA、TreeComponent)的 create_match_validator 签名同步修改,若新增组件未实现 match_device_only 将引发 TypeError(抽象方法)。

影响范围:使用 HiCache 且启用了 Mamba 或 SWA 的用户。正确性收益:修复了之前可能出现的 last_device_node 错误导致 prefix indices 包含 host-only 节点的问题;对非 HiCache 用户无行为影响。性能影响:新增的双 validators 遍历在 HiCache 路径上有微小开销,但可通过 generator 优化缓解。

核心路径变更 Mamba 分支禁用 缺少非 HiCache 回归测试(但已有覆盖)

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论