Prhub

#24026 [SWA] Fix missing mamba_indices parameter in cpu copy interface

原始 PR 作者 merrymercy 合并时间 2026-04-30 08:33 文件变更 7 提交数 6 评论 3 代码增减 +41 / -37

执行摘要

修复 SWA 模型中 get_cpu_copy 缺少 mamba_indices 参数导致的崩溃

修复 TypeError: SWATokenToKVPoolAllocator.get_cpu_copy() got an unexpected keyword argument mamba_indices 崩溃(当 SWA 模型触发 retract_decode 路径时)。PR #22493 添加 mamba_indices kwarg 到 offload_kv_cache/load_kv_cache 调用方并更新了 TokenToKVPoolAllocator/PagedTokenToKVPoolAllocator,但遗漏了 SWATokenToKVPoolAllocator(位于单独文件中)。

推荐阅读。该 PR 展示了一个典型的接口扩展之后遗漏子类导致的 bugfix 过程。设计决策上,作者选择显式参数而非 **kwargs,提升了代码可读性和类型安全性。值得关注的是如何系统性地扫描整个类层次结构并统一修改。

讨论亮点

评审机器人(gemini-code-assist[bot])在第一个 commit 的版本上指出:swa_memory_pool.py 中添加的 **kwargs 并未转发到底层 full_kv_poolswa_kv_pool 的调用,导致额外参数被丢弃。PR 随后迭代将 **kwargs 替换为显式 mamba_indices=None 并在所有调用链中传递,使得该问题自然解决。最终版本中不再有未转发的 kwargs。

实现拆解

实现拆解:

  1. 快速修复:在 swa_memory_pool.pySWAKVPoolSWATokenToKVPoolAllocator 方法中添加 **kwargs 以匹配基类接口,但未转发(review 指出)。
  2. 替换为显式参数:将整个类层次结构中的 get_cpu_copy/load_cpu_copy 签名从 **kwargs 改为显式 mamba_indices=None 参数,包括:
    • BaseTokenToKVPoolAllocatorallocator.py
    • TokenToKVPoolAllocatorPagedTokenToKVPoolAllocatorallocator.py
    • KVCachememory_pool.py,基类及其子类 MHATokenToKVPoolMLATokenToKVPoolHybridLinearKVPool
    • SWAKVPoolSWATokenToKVPoolAllocatorswa_memory_pool.py
    • HiSparseDevicePoolhisparse_memory_pool.py,虽然 raise NotImplementedError 但保持接口一致)
    • MambaPoolmemory_pool.py,移除未使用的 **kwargs
  3. 格式化与合并:运行 black 格式化,并通过 merge 引入 main 分支同时修复 lint 问题。
  4. 遗漏补漏:最终 commit 确保 HiSparseNSATokenToKVPool 的 cpu copy 方法也添加了参数。

测试方面,未新增专门测试用例,依赖现有 SWA 和 mamba 的 retraction 测试。

文件 模块 状态 重要度
python/sglang/srt/mem_cache/allocator.py 分配器 modified 6.72
python/sglang/srt/mem_cache/memory_pool.py 内存池 modified 6.63
python/sglang/srt/mem_cache/swa_memory_pool.py SWA 池 modified 6.45
python/sglang/srt/mem_cache/hisparse_memory_pool.py 稀疏池 modified 5.87
python/sglang/srt/model_loader/loader.py 模型加载 modified 5.46
python/sglang/srt/model_executor/model_runner.py 模型运行 modified 5.46
python/sglang/srt/server_args.py 启动参数 modified 5.22

关键符号

BaseTokenToKVPoolAllocator.get_cpu_copy TokenToKVPoolAllocator.get_cpu_copy PagedTokenToKVPoolAllocator.get_cpu_copy KVCache.get_cpu_copy MHATokenToKVPool.get_cpu_copy MLATokenToKVPool.get_cpu_copy HybridLinearKVPool.get_cpu_copy SWAKVPool.get_cpu_copy SWATokenToKVPoolAllocator.get_cpu_copy MambaPool.get_cpu_copy BaseTokenToKVPoolAllocator.load_cpu_copy TokenToKVPoolAllocator.load_cpu_copy PagedTokenToKVPoolAllocator.load_cpu_copy KVCache.load_cpu_copy MHATokenToKVPool.load_cpu_copy MLATokenToKVPool.load_cpu_copy HybridLinearKVPool.load_cpu_copy SWAKVPool.load_cpu_copy SWATokenToKVPoolAllocator.load_cpu_copy MambaPool.load_cpu_copy

关键源码片段

python/sglang/srt/mem_cache/allocator.py core-logic

核心分配器基类和 TokenToKV 分配器,定义并显式化 get_cpu_copy/load_cpu_copy 签名

# python/sglang/srt/mem_cache/allocator.py (head) 关键方法class BaseTokenToKVPoolAllocator:
    # ... 省略其他方法 ...
    def get_cpu_copy(self, indices, mamba_indices=None):
        # FIXME: 等 paged allocator 实现后复用 get_cpu_copy
        raise NotImplementedError()
​
    def load_cpu_copy(self, kv_cache_cpu, indices, mamba_indices=None):
        # FIXME: 等 paged allocator 实现后复用 load_cpu_copy
        raise NotImplementedError()class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    # ... 省略其他方法 ...
    def get_cpu_copy(self, indices, mamba_indices=None):
        # 显式传递 `mamba_indices`,而非模糊的 `**kwargs`
        return self._kvcache.get_cpu_copy(indices, mamba_indices=mamba_indices)
​
    def load_cpu_copy(self, kv_cache_cpu, indices, mamba_indices=None):
        return self._kvcache.load_cpu_copy(
            kv_cache_cpu, indices, mamba_indices=mamba_indices
        )
python/sglang/srt/mem_cache/swa_memory_pool.py core-logic

SWA 特定 KV 池和分配器,是本次修复的核心遗漏点

# python/sglang/srt/mem_cache/swa_memory_pool.py (head) 关键方法class SWAKVPool:
    # ... 省略其他方法 ...
    def get_cpu_copy(self, indices, mamba_indices=None):
        # 复制 Full 和 SWA 两个池子的 KV cache
        full_kv_cpu = self.full_kv_pool.get_cpu_copy(indices)
​
        if self.full_to_swa_index_mapping is not None:
            swa_indices = self.full_to_swa_index_mapping[indices]
            swa_kv_cpu = self.swa_kv_pool.get_cpu_copy(swa_indices)
        else:
            swa_kv_cpu = None
​
        return {"full": full_kv_cpu, "swa": swa_kv_cpu}
​
    def load_cpu_copy(self, kv_cache_cpu, indices, mamba_indices=None):
        # 注意:这里的 indices 是新分配的索引,与 get_cpu_copy 的不同
        full_kv_cpu = kv_cache_cpu["full"]
        swa_kv_cpu = kv_cache_cpu["swa"]
​
        self.full_kv_pool.load_cpu_copy(full_kv_cpu, indices)
        if swa_kv_cpu is not None and self.full_to_swa_index_mapping is not None:
            swa_indices = self.full_to_swa_index_mapping[indices]
            self.swa_kv_pool.load_cpu_copy(swa_kv_cpu, swa_indices)class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    # ... 省略其他方法 ...
    def get_cpu_copy(self, indices, mamba_indices=None):
        return self._kvcache.get_cpu_copy(indices, mamba_indices=mamba_indices)
​
    def load_cpu_copy(self, kv_cache_cpu, indices, mamba_indices=None):
        return self._kvcache.load_cpu_copy(
            kv_cache_cpu, indices, mamba_indices=mamba_indices
        )

评论区精华

**Kwargs 未转发到底层调用 正确性

评审机器人指出第一个 commit 在 swa_memory_pool.py 中添加的 **kwargs 没有传递给 self.full_kv_pool.get_cpu_copy 和 self.swa_kv_pool.get_cpu_copy,导致额外参数被丢弃。

结论:PR 后续将 **kwargs 替换为显式 `mamba_indices=None` 并在所有调用链中传递,问题自然解决。 · 已解决

风险与影响

风险较低。核心改动仅为接口参数从 **kwargs 转为显式参数,调用方已统一传递 mamba_indices。但存在以下潜在风险:

  • 若某些自定义调用方未传递 mamba_indices 且依赖旧签名,但 Python 支持关键字参数缺省为 None,因此向后兼容。
  • MambaPool 中原 **kwargs 被移除,若存在未发现的传递额外参数的地方会引发 TypeError,但检查后无实际传递。
  • 缺少对新接口的单元测试,回归风险主要依赖集成测试覆盖。

用户影响:修复了 SWA 模型(如 Llama4)在 retract_decode 路径上的崩溃,对非 SWA 模型无功能影响。
系统影响:接口签名明确化有助于静态分析和代码理解。
团队影响:需注意未来扩展 get_cpu_copy/load_cpu_copy 接口时,所有子类必须同步更新参数。

核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论