执行摘要
- 一句话:修复 HiCache 下 UnifiedRadixCache 的设备匹配语义
- 推荐动作:建议精读,尤其是设计模式:如何在一个匹配过程中同时跟踪 best host match 和 best device match。
create_match_validator 的参数化设计值得借鉴。如果团队正在开发缓存层或类似的分层匹配系统,此 PR 提供了清晰的参考实现。
功能与动机
此 PR 阐明 UnifiedRadixCache 在 HiCache 启用时的匹配语义:分离 best_match_node(允许 host-backed 节点)和用作 last_device_node 的设备驻留匹配锚点;追踪 last_device_node 作为所有活跃组件均为设备驻留的最深节点;在 HiCache 下临时禁用 Mamba branching seqlen;更新 init_load_back 测试以验证前缀索引收集、仅 aux load-back 和回退行为。
实现拆解
- 在
_match_prefix_helper 中引入 best_match_device_node 和 separate_device_match 标志,当 HiCache 激活时(cache_controller is not None),同时维护两套 validators(全量 validators 和仅设备 validators);非 HiCache 模式行为不变。
- 为所有
create_match_validator 加入 match_device_only 参数,各组件(Full、Mamba、SWA、TreeComponent)各自实现过滤逻辑:仅设备 validators 忽略 host_value。
- 删除未使用的
_match_prefix_helper_readonly 方法,其逻辑已合入新版 _match_prefix_helper;更新调用者(test_unified_radix_cache_unittest.py)中的匹配断言。
- 在
MambaComponent.finalize_match_result 中当 HiCache 激活时禁用 branching seqlen 填充(self.cache.cache_controller is None 条件),避免在 host-backed 状态下触发未测试路径。
- 在
init_load_back 中新增 _collect_new_prefix_indices 辅助函数,从 best_match_node 向上收集前缀索引直到 last_best_match_device_node;简化 success 路径(移除 None 检查死代码)。
- 测试方面:新增
test_match_prefix_best_and_device_node_without_hicache、test_hicache_mamba_host_best_match_keeps_device_anchor、test_hicache_swa_host_best_match_keeps_device_anchor、test_mamba_branching_seqlen_disabled_under_hicache;修改 _load_back_node 返回逻辑使测试可断言加载后的 value 存在。
关键文件:
python/sglang/srt/mem_cache/unified_radix_cache.py(模块 缓存层;类别 source;类型 core-logic;符号 _match_prefix_helper, _match_prefix_helper_readonly, _all_valid, _collect_new_prefix_indices): 核心匹配逻辑重构:分离 best_match_node 与 best_match_device_node,删除未使用的 readonly helper,新增 _all_valid 生成器检查,重写 init_load_back 的前缀索引收集
test/registered/unit/mem_cache/test_unified_radix_cache_unittest.py(模块 单元测试;类别 test;类型 test-coverage;符号 test_match_prefix_best_and_device_node_without_hicache, test_hicache_mamba_host_best_match_keeps_device_anchor, test_hicache_swa_host_best_match_keeps_device_anchor, test_mamba_branching_seqlen_disabled_under_hicache): 新增 3 个 HiCache 锚点测试和 1 个 Mamba branching 禁用测试,调整 init_load_back 测试辅助函数以验证加载后的值存在
python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py(模块 缓存层;类别 source;类型 core-logic;符号 create_match_validator, finalize_match_result): create_match_validator 加入 match_device_only 参数;finalize_match_result 中禁用 HiCache 下的 branching seqlen
python/sglang/srt/mem_cache/unified_cache_components/full_component.py(模块 缓存层;类别 source;类型 core-logic;符号 create_match_validator): create_match_validator 加入 match_device_only 参数,仅设备模式检查 value 不为 None
python/sglang/srt/mem_cache/unified_cache_components/tree_component.py(模块 缓存层;类别 source;类型 core-logic;符号 create_match_validator): 抽象方法签名增加 match_device_only 参数,所有子类必须同步修改
python/sglang/srt/mem_cache/unified_cache_components/swa_component.py(模块 缓存层;类别 source;类型 core-logic;符号 create_match_validator): create_match_validator 加入 match_device_only 参数,仅设备模式下忽略 host_value
test/registered/radix_cache/test_unified_radix_hicache_kl.py(模块 集成测试;类别 test;类型 test-coverage): 轻微调整一行,可能是取消跳过条件
关键符号:_match_prefix_helper, _match_prefix_helper_readonly, _all_valid, _collect_new_prefix_indices, create_match_validator, finalize_match_result, init_load_back
关键源码片段
python/sglang/srt/mem_cache/unified_radix_cache.py
核心匹配逻辑重构:分离 best_match_node 与 best_match_device_node,删除未使用的 readonly helper,新增 _all_valid 生成器检查,重写 init_load_back 的前缀索引收集
def _match_prefix_helper(
self, key: RadixKey
) -> tuple[list[torch.Tensor], UnifiedTreeNode, UnifiedTreeNode, int]:
# 新设计:在 HiCache 模式下分为全量 validators(允许 host-backed 节点)
# 和设备 validators(仅允许 device 驻留节点),用于分别追踪
# best_match_node(允许 host)和 best_match_device_node(仅 device)。
node = self.root_node
child_key = key.child_key(self.page_size)
value: list[torch.Tensor] = []
best_match_node = node
best_match_device_node = node
best_match_device_value_len = 0
separate_device_match = self.cache_controller is not None
if separate_device_match:
validators = tuple(
comp.create_match_validator() for comp in self._components_tuple
)
device_validators = tuple(
comp.create_match_validator(match_device_only=True)
for comp in self._components_tuple
)
else:
# 非 HiCache 模式:全量 validators == 设备 validators
validators = tuple(
comp.create_match_validator(match_device_only=True)
for comp in self._components_tuple
)
device_validators = validators
def _all_valid(validators, node):
# 使用生成器表达式允许短路,避免列表分配
return all(v(node) for v in validators)
while len(key) > 0 and child_key in node.children:
child = node.children[child_key]
# 死节点(evicted 且无 host backup)停止遍历
if child.evicted and not child.backuped:
break
prefix_len = child.key.match(key, page_size=self.page_size)
if prefix_len < len(child.key):
# 部分匹配:不 split,停止
break
if not child.evicted:
value.append(child.component_data[BASE_COMPONENT_TYPE].value)
node = child
if _all_valid(validators, node):
best_match_node = node
best_value_len = len(value) # 对应旧的 best_value_len
if separate_device_match and _all_valid(device_validators, node):
best_match_device_node = node
best_match_device_value_len = len(value)
# 注意:非 HiCache 下同时更新 best_match_device_node 由调用者负责
key = key[prefix_len:]
if len(key):
child_key = key.child_key(self.page_size)
return value, best_match_node, best_match_device_node, best_match_device_value_len
python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py
create_match_validator 加入 match_device_only 参数;finalize_match_result 中禁用 HiCache 下的 branching seqlen
def create_match_validator(
self, match_device_only: bool = False
) -> Callable[[UnifiedTreeNode], bool]:
ct = self.component_type
if match_device_only:
# 仅设备模式:只检查 device 上的 value 是否存在
return lambda node: node.component_data[ct].value is not None
# HiCache 模式:evicted + backuped(host_value 存在)也是有效匹配
return lambda node: (
node.component_data[ct].value is not None
or node.component_data[ct].host_value is not None
)
def finalize_match_result(
self, result, params, value_chunks, best_value_len
) -> MatchResult:
# HiCache 可以继续使用前缀匹配并加载 host-backed 的 Mamba 状态。
# 暂时跳过 branching-state 填充,后续可添加 HiCache 感知的分支策略。
if self.cache.cache_controller is None and len(value_chunks) > best_value_len:
# 非 HiCache:计算 branching_seqlen
chunk_size = get_global_server_args().mamba_cache_chunk_size
aligned_seqlen = (
sum(len(v) for v in value_chunks) // chunk_size
) * chunk_size
branching_seqlen = aligned_seqlen if aligned_seqlen > 0 else None
else:
branching_seqlen = None
# ... 继续处理 cow_mamba
评论区精华
Review 中 gemini-code-assist[bot] 提出了三个中等优先级的优化建议:
风险与影响
- 风险:
1) 核心路径变更:_match_prefix_helper 是每次请求前缀匹配的必经函数,重构后语义分离可能引入回归,但新增的单元测试覆盖了主要场景。
2) Mamba branching seqlen 禁用:当 HiCache 激活且匹配到 host-backed 节点时,Mamba 的 branching 状态填充被跳过,可能影响 Mamba 模型的缓存跳跃性能,但该路径在 HiCache 下原本就不稳定,禁用更为安全。
3) 组件 validator 的一致性:所有四个组件(Full、Mamba、SWA、TreeComponent)的 create_match_validator 签名同步修改,若新增组件未实现 match_device_only 将引发 TypeError(抽象方法)。
- 影响:影响范围:使用 HiCache 且启用了 Mamba 或 SWA 的用户。正确性收益:修复了之前可能出现的
last_device_node 错误导致 prefix indices 包含 host-only 节点的问题;对非 HiCache 用户无行为影响。性能影响:新增的双 validators 遍历在 HiCache 路径上有微小开销,但可通过 generator 优化缓解。
- 风险标记:核心路径变更, Mamba 分支禁用, 缺少非 HiCache 回归测试(但已有覆盖)
关联脉络
- PR #25348 [UnifiedTree]: Add nightly hicache ci for dsa model: 同一功能线(HiCache)的 CI 覆盖,与此 PR 共同巩固 HiCache 稳定性
- PR #25252 [Lint] Fix
Optional[X] = (None,) typo defaults in two dataclasses: 修复了 hicache_storage 的默认值 typo,与此 PR 涉及的 HiCache 存储层相关
参与讨论