Prhub

#23675 perf: add --prefill-only-disable-kv-cache to skip KV pool allocation

原始 PR 作者 jasperjiaguo 合并时间 2026-05-12 04:10 文件变更 4 提交数 9 评论 19 代码增减 +347 / -1

执行摘要

跳过 KV 缓存池分配,节省显存并提升吞吐

在纯 prefill 工作负载中,注意力通过 flash_attn_varlen_func 使用原始 K/V,无需读写物理 KV 缓存。但之前引擎总是分配巨大的 KV 缓冲区,造成显存浪费。该标志消除了这种浪费。

建议精读。该 PR 的设计模式(no-op pool 子类保持接口兼容)有参考价值。对于 embedding 服务用户,建议启用该标志以获得显存收益。代码结构清晰,测试完善(8 种组合),值得团队内部学习。

讨论亮点

reviewer hzh0425 指出该标志初期不兼容 context-parallel prefill(因为 CP 路径会调用 set_kv_buffer),且仅支持 MHA 池。作者积极响应,添加了多项检查:在 ServerArgs 中拒绝 attn_cp_size > 1enable_prefill_context_parallel,在 _validate_prefill_only_disable_kv_cache_pool_family 中列出不支持的池族并给出清晰错误。经过多轮迭代后,hzh0425 给予了批准。

实现拆解

  1. 新增 prefill_only_disable_kv_cache 布尔参数到 ServerArgs(位于 python/sglang/srt/server_args.py),并添加 CLI 标志。
  2. 实现 NoOpMHATokenToKVPool 子类(位于 python/sglang/srt/mem_cache/memory_pool.py),重写 _create_buffers 分配极小占位符(形状 [page_size, head_num, head_dim]),重写 _finalize_allocation_log 报告零内存使用,重写 get_kv_size_bytes 返回零,重写 set_kv_buffer 抛出 RuntimeError 以防止误用。
  3. ServerArgs.__post_init__ 中调用验证方法,检查使用条件:必须设置 --is-embedding--chunked-prefill-size=-1--disable-radix-cache,且不兼容 context parallel、HiSparse、FP4 KV 缓存等。
  4. ModelRunnerKVCacheMixin._init_pools 中,根据标志选择 NoOpMHATokenToKVPool 而非 MHATokenToKVPool,并调用 _validate_prefill_only_disable_kv_cache_pool_family 拒绝不支持的池族(如 MLA、SWA、Mamba 等)。
  5. 添加单元测试(test/registered/unit/server_args/test_server_args.py)验证各种合法与非法配置组合。
文件 模块 状态 重要度
python/sglang/srt/mem_cache/memory_pool.py 内存池 modified 8.37
python/sglang/srt/server_args.py 参数配置 modified 7.63
python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py 运行池 modified 7.45
test/registered/unit/server_args/test_server_args.py 单元测试 modified 6.73

关键符号

NoOpMHATokenToKVPool _create_buffers _finalize_allocation_log get_kv_size_bytes set_kv_buffer _validate_prefill_only_disable_kv_cache_args _handle_prefill_only_disable_kv_cache _validate_prefill_only_disable_kv_cache_pool_family test_valid_minimal_config_constructs test_rejects_when_not_embedding test_rejects_when_chunked_prefill_size_not_minus_one test_rejects_when_radix_cache_enabled test_rejects_attn_cp_size_greater_than_one test_rejects_prefill_context_parallel

关键源码片段

python/sglang/srt/mem_cache/memory_pool.py core-logic

核心实现:新增 `NoOpMHATokenToKVPool` 子类,通过极小占位符替代 GB 级 KV 缓冲区,是跳过缓存分配的关键。

class NoOpMHATokenToKVPool(MHATokenToKVPool):
    """KV cache pool that skips physical K/V buffer allocation.    在 embedding 模式 prefill-only 工作负载中使用 FA backend 的
    fa_skip_kv_cache 路径时,attention 通过 flash_attn_varlen_func
    使用原始 K/V,无需读写池。该类保持调度器的容量视图,
    但仅分配 (page_size, head_num, head_dim) 占位符。
    """
​
    def _create_buffers(self):
        # 分配极小占位符。形状为 [page_size, head_num, head_dim] 每层,
        # 使得 FA backend 顶部的 view 操作无条件成功。
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            self.k_buffer = [
                torch.zeros(
                    (self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (self.page_size, self.head_num, self.v_head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
        self.k_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.k_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self.v_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.v_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
        self.data_strides = torch.tensor(
            [np.prod(x.shape[1:]) * x.dtype.itemsize for x in self.k_buffer + self.v_buffer],
            device=self.device,
        )
​
    def _finalize_allocation_log(self, num_tokens: int):
        self.mem_usage = 0.0
        placeholder_bytes = (2 * self.layer_num * self.page_size * self.head_num
                             * max(self.head_dim, self.v_head_dim) * self.store_dtype.itemsize)
        logger.info(
            f"KV Cache skipped (no-op pool). Logical #tokens: {num_tokens}, "
            f"physical K/V size: ~{placeholder_bytes / 1024:.1f} KB placeholder"
        )
​
    def get_kv_size_bytes(self):
        # 报告零,使下游内存核算反映真实情况。
        return (0, 0)
​
    def set_kv_buffer(self, *args, **kwargs):
        raise RuntimeError(
            "NoOpMHATokenToKVPool.set_kv_buffer was called. This pool is only "
            "valid in prefill-only modes (e.g. --is-embedding, scoring) with "
            "the FA backend's fa_skip_kv_cache path active; the attention "
            "backend must never write to it."
        )
python/sglang/srt/server_args.py core-logic

参数定义与条件验证:新增 CLI 标志并实现两组前置检查,确保该标志仅在合法配置下使用。

def _validate_prefill_only_disable_kv_cache_args(self):
    """为 --prefill-only-disable-kv-cache 执行标志/前置条件约束验证。
    在 dummy-model 短路之前运行,以便及早拒绝错误配置。
    """
    if not self.prefill_only_disable_kv_cache:
        return
​
    # 目前限定为 embedding 模式,后续可扩展到其他 prefill-only 工作负载。
    if not self.is_embedding:
        raise ValueError(
            "--prefill-only-disable-kv-cache currently requires --is-embedding. "
            "Other prefill-only workloads may be supported in a future change."
        )
    if self.kv_cache_dtype == "fp4_e2m1":
        raise ValueError(
            "--prefill-only-disable-kv-cache does not support --kv-cache-dtype=fp4_e2m1."
        )
    # 结构前提:chunked_prefill_size == -1 且 disable_radix_cache。
    if self.chunked_prefill_size != -1:
        raise ValueError("--prefill-only-disable-kv-cache requires --chunked-prefill-size=-1.")
    if not self.disable_radix_cache:
        raise ValueError("--prefill-only-disable-kv-cache requires --disable-radix-cache.")
​
    # 不兼容 context parallel prefill(CP 路径会调用 set_kv_buffer)。
    if self.attn_cp_size is not None and self.attn_cp_size > 1:
        raise ValueError("--prefill-only-disable-kv-cache is incompatible with --attn-cp-size>1.")
    if self.enable_prefill_context_parallel:
        raise ValueError("--prefill-only-disable-kv-cache is incompatible with --enable-prefill-context-parallel.")

评论区精华

不兼容 context-parallel prefill 和缺少池族保护 设计

reviewer hzh0425 指出标志与 context-parallel prefill 不兼容(CP 路径调用 set_kv_buffer),且仅支持 MHA 池。建议在 ServerArgs 中拒绝 CP 配置,并添加池族检查。

结论:作者接受反馈,在 ServerArgs 中拒绝 attn_cp_size > 1 和 enable_prefill_context_parallel,并在 _init_pools 中添加 _validate_prefill_only_disable_kv_cache_pool_family 明确列出不支持的池族。 · 已解决

函数拆分与代码组织 style

reviewer hzh0425 建议将过长的验证逻辑提取为独立函数。

结论:作者将 __post_init__ 中的内联验证拆分为 `_validate_prefill_only_disable_kv_cache_args` 和 `_handle_prefill_only_disable_kv_cache` 两个方法。 · 已解决

风险与影响

  1. 如果用户在非 prefill-only 模式下错误启用标志,set_kv_buffer 会抛出 RuntimeError,导致请求失败——这比静默数据损坏更安全,但仍可能影响服务可用性。
  2. 标志目前仅支持 CUDA FA 后端(fa3/fa4),其他后端如 NPU、AMD 等未经过测试,可能在启用时启动失败。
  3. 代码修改集中在核心内存池和调度逻辑,可能影响未来其他工作负载的兼容性。
  4. 占位符形状依赖 page_size,若 FA 后端未来修改 view 操作,可能引发不匹配。

对 embedding/scoring 用户是显著改进:显存释放数十 GB,吞吐提升约 6%。默认未启用,对现有用户无影响。新增标志文档化后,有利于 embedding 服务的部署成本优化。团队需要确保该标志与其他新增特性(如 NSA、MLA)的兼容性。

错误启用导致请求失败 仅支持 CUDA FA 后端 依赖特定参数组合 核心内存池修改影响面广

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论