Prhub

#25022 [Bugfix, NSA HiCache] Fix missing override_kv_cache_dim in attach_hybrid_nsa_pool_to_hiradix_cache

原始 PR 作者 taegeonum 合并时间 2026-05-13 11:45 文件变更 3 提交数 12 评论 1 代码增减 +12 / -24

执行摘要

修复 NSA HiCache 中缺失的 override_kv_cache_dim 参数

根据 PR body,由于在重构 hybrid_pool_assembler.py(PR#23243)时遗漏参数,导致运行时出现 RuntimeError: The size of tensor a (576) must match the size of tensor b (656) at non-singleton dimension 2。修复此问题以恢复 NSA/DeepSeek 模型在 HiCache 下的正常推理。

值得精读,展示如何从重构中引入细微 bug 并修复,同时进行接口清理。关注 override_kv_cache_dim 在共享锚点中的必要性,以及改为传递通信组对象的设计思路。

讨论亮点

无实质性讨论,PR 由 hzh0425 审批通过,未提出额外问题。

实现拆解

  1. attach_hybrid_nsa_pool_to_hiradix_cache 函数中,向 build_shared_anchor_stack 调用添加 override_kv_cache_dim=kv.kv_cache_dim 参数。
  2. build_kv_only_stackbuild_hybrid_swa_stackbuild_hybrid_mamba_stackbuild_shared_anchor_stack 中移除不再需要的 attn_cp_rankattn_cp_size 参数。
  3. 在上述函数中新增 attn_cp_groupattn_tp_group 参数,并在构造 HybridCacheController 时传递它们,替换原来的 rank/size 参数。
  4. unified_radix_cache.pyinit_hicache 方法中,调用 attach_hybrid_pool_to_unified_cache 时传递 attn_cp_group=params.attn_cp_cache_groupattn_tp_group=params.attn_tp_cache_group,确保通信组信息正确传播。
  5. hybrid_cache_controller.py 中移除 __init__attn_cp_rankattn_cp_size 参数,这些信息已由通信组对象直接携带。未直接添加测试文件,改动依赖回归测试。
文件 模块 状态 重要度
python/sglang/srt/mem_cache/hybrid_cache/hybrid_pool_assembler.py 缓存层 modified 6.19
python/sglang/srt/mem_cache/hybrid_cache/hybrid_cache_controller.py 缓存层 modified 4.82
python/sglang/srt/mem_cache/unified_radix_cache.py 缓存层 modified 4.82

关键符号

attach_hybrid_nsa_pool_to_hiradix_cache build_kv_only_stack build_hybrid_swa_stack build_hybrid_mamba_stack build_shared_anchor_stack HybridCacheController.__init__ UnifiedRadixCache.init_hicache

关键源码片段

python/sglang/srt/mem_cache/hybrid_cache/hybrid_pool_assembler.py core-logic

核心修复:在 attach_hybrid_nsa_pool_to_hiradix_cache 中添加缺失的 override_kv_cache_dim 参数;同时清理多个 build_*_stack 函数的参数,使用 attn_cp_group/attn_tp_group 替代 attn_cp_rank/attn_cp_size。

def build_kv_only_stack(
    *,
    params: CacheInitParams,
    server_args: ServerArgs,
    kv_pool: Any,
    full_layer_mapping: dict[int, int],
    page_size: int,
    tp_group,
    load_cache_event,
    # 新增参数:替换了旧的 attn_cp_rank 和 attn_cp_size
    attn_cp_group: Optional["torch.distributed.ProcessGroup"] = None,
    attn_tp_group: Optional["torch.distributed.ProcessGroup"] = None,
    storage_backend: Optional[str],
    use_mla: bool,
    override_kv_cache_dim: Optional[int] = None, # 可覆盖 kv 缓存维度
    prefetch_threshold: int = 256,
    model_name: Optional[str] = None,
    storage_backend_extra_config: Optional[dict] = None,
    pp_rank: int = 0,
    pp_size: int = 1,
    enable_storage_metrics: bool = False,
) -> tuple[HostPoolGroup, HybridCacheController]:
    transfer_layer_num = len(full_layer_mapping)
    kv_host_pool = build_kv_host_pool(
        kv_pool=kv_pool, page_size=page_size,
        server_args=server_args, use_mla=use_mla,
        override_kv_cache_dim=override_kv_cache_dim,
    )
    entries = [
        build_pool_entry(
            name=PoolName.KV, host_pool=kv_host_pool,
            device_pool=kv_pool, layer_mapping=full_layer_mapping,
            transfer_layer_num=transfer_layer_num, is_anchor=True,
        )
    ]
    host_pool_group = HostPoolGroup(entries)
    cache_controller = HybridCacheController(
        params.token_to_kv_pool_allocator, host_pool_group,
        page_size, tp_group,
        load_cache_event=load_cache_event,
        attn_cp_group=attn_cp_group, # 传递通信组对象而非 rank
        attn_tp_group=attn_tp_group,
        write_policy=server_args.hicache_write_policy,
        io_backend=server_args.hicache_io_backend,
        storage_backend=storage_backend,
        prefetch_threshold=prefetch_threshold,
        model_name=model_name,
        storage_backend_extra_config=storage_backend_extra_config,
        pp_rank=pp_rank, pp_size=pp_size,
        transfer_layer_num=transfer_layer_num,
        enable_storage_metrics=enable_storage_metrics,
    )
    return host_pool_group, cache_controller
python/sglang/srt/mem_cache/unified_radix_cache.py core-logic

在 init_hicache 中调用 attach_hybrid_pool_to_unified_cache 时传递 attn_cp_group 和 attn_tp_group,确保通信组信息传递给下级构造。

    def init_hicache(self, server_args: ServerArgs, params: CacheInitParams) -> None:
        """Initialize HiCache infrastructure."""
        from sglang.srt.mem_cache.hybrid_cache.hybrid_pool_assembler import (
            attach_hybrid_pool_to_unified_cache,
        )
​
        # Direct IO layout fixup (must happen before pool creation)
        if server_args.hicache_io_backend == "direct":
            if server_args.hicache_mem_layout == "page_first":
                server_args.hicache_mem_layout = "page_first_direct"
                logger.warning(
                    "Page first layout is not supported with direct IO backend, "
                    "switching to page first direct layout"
                )
​
        self.load_cache_event = threading.Event()
        self.hicache_anchor_kv_shared_indices_pools.clear()
        attach_hybrid_pool_to_unified_cache(
            self,
            params,
            server_args,
            load_cache_event=self.load_cache_event,
            # 新增传递通信组参数,替换之前的 attn_cp_rank/attn_cp_size
            attn_cp_group=params.attn_cp_cache_group,
            attn_tp_group=params.attn_tp_cache_group,
        )
​
        # State initialization
        self.write_through_threshold = (
            1 if server_args.hicache_write_policy == "write_through" else 2
        )
        self.load_back_threshold = 256
​
        logger.info(
            f"HiCache D\u2194H initialized: "
            f"host_pool_size={self.host_pool_group.size}, "
            f"write_policy={server_args.hicache_write_policy}, "
            f"tp_world_size={self.tp_world_size}, "
            f"transfer_layer_num={self.cache_controller.layer_num}"
        )

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

核心修复仅涉及一行参数添加,改动明显,回归风险低。参数清理可能影响其他未发现的调用方式,但所有修改点在代码中一致更新。主要风险是缺少针对 NSA+HiCache 的组合测试覆盖,若未来有类似重构容易再次遗漏。

影响所有使用 NSA/DeepSeek 模型并启用 HiCache 的用户,修复了运行时崩溃。参数清理统一了通信组传递方式,简化了接口,对开发者友好。影响范围集中在 HiCache 路径,不涉及其他功能。

核心路径变更 缺少测试覆盖

关联 Issue

#23243 [Hybrid-Cache]: Refactor hybrid_pool_assembler.py

完整报告

参与讨论