Prhub

#26810 Add KV-canary SWA + DeepSeek-V4 pool support

Original PR · Author: fzyzcjy · Merged: 2026-05-31 09:55 · Files changed: 10 · Commits: 1 · Comments: 6 · Lines: +334 / -3

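Executive summary

Extends KV-canary to support SWA and DeepSeek-V4 KV pools

The PR body states: 'Layer sliding-window-attention (SWA) and DeepSeek-V4 KV-pool support onto the MHA-only canary core'. The goal is to widen KV-canary coverage so that it can monitor the KV cache of SWA and DeepSeek-V4 models.

Recommendation: before merging, handle the import-compatibility issue (wrap the DeepSeekV4TokenToKVPool import in try-except) and fix the typo. The PR's adapter-pattern design is clear and is a useful reference for future extensions.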

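Discussion highlights

gemini-code-assist[bot] made several suggestions:

  • In api.py, the import of DeepSeekV4TokenToKVPool should be wrapped in try-except to avoid a startup crash in environments without DeepSeek support.
  • In mode_config.py, the model path contains a typo (gemma-4-E2B-it should be gemma-2-2b-it).
  • In dsv4.py, add a defensive check for pool.swa_kv_pool being None.
  • In swa.py, validate the sub-pool's k_buffer before indexing it.
  • In fixtures.py, use b.shape[1:].numel() instead of b[0] to avoid an IndexError on empty tensors.
  • In e2e_base.py, a comment references the wrong model name.
    As of the merge, it is unclear whether these suggestions were adopted (the diff shows no corresponding changes).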
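
The first suggestion above (guarding an optional import) can be sketched roughly as follows. This is a minimal illustration, not the code in api.py: `build_pool_attachers`, the stand-in `SWAKVPool` class, and the placeholder attach functions are all hypothetical, and the module and class names are passed in as strings only so the sketch runs without sglang installed.

```python
import importlib
from typing import Callable, Dict


class SWAKVPool:
    """Stand-in for the real pool class (hypothetical, illustration only)."""


def attach_swa(pool: object) -> str:
    # Placeholder attacher; the real one returns canary buffer groups.
    return "swa"


def attach_dsv4(pool: object) -> str:
    # Placeholder attacher for the optional pool type.
    return "dsv4"


def build_pool_attachers(module_name: str, class_name: str) -> Dict[type, Callable]:
    """Register the optional adapter only if its pool class can be imported."""
    attachers: Dict[type, Callable] = {SWAKVPool: attach_swa}
    try:
        optional_cls = getattr(importlib.import_module(module_name), class_name)
    except (ImportError, AttributeError):
        # Optional dependency is unavailable: keep the core adapters only.
        return attachers
    attachers[optional_cls] = attach_dsv4
    return attachers
```

One trade-off of this pattern is that a genuinely broken import is silently skipped, so a log message in the except branch is usually worth adding.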

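Implementation breakdown

  1. New adapter modules: swa.py implements attach_swa, which uses _build_subpool_group to create a CanaryBufferGroup for each of the full and SWA sub-pools; dsv4.py implements attach_dsv4, which handles only the SWA sub-pool (the other sub-pools are not covered).
  2. Adapter registration: api.py imports the new adapter functions and adds SWAKVPool: attach_swa and DeepSeekV4TokenToKVPool: attach_dsv4 to the _POOL_ATTACHERS dictionary.
  3. Test support: fixtures.py adds the FakeSWAPool and FakeSwaSubPool mock classes and the make_swa_pool factory; consts.py defines server arguments and environment variables for SWA and DSV4.
  4. Unit tests: test_self_unit_pool_patcher.py adds test_canary_buffer_group_allocate_full_and_swa and test_swa_attach_splices_full_into_contiguous_and_swa_into_state, which verify buffer-group allocation and the buf_info output after patching.
  5. End-to-end tests: e2e_base.py and mode_config.py gain 'swa' and 'dsv4' modes; a TestBaselineSwa class is added to test_self_e2e_baseline.py, along with a new test_self_e2e_baseline_dsv4.py.

File | Module | Status | Importance
python/sglang/srt/kv_canary/pool_patcher/adapters/swa.py | KV-canary | added | 8.08
python/sglang/srt/kv_canary/pool_patcher/adapters/dsv4.py | KV-canary | added | 7.54
python/sglang/test/kv_canary/fixtures.py | Test utilities | modified | 7.24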
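
The registration step in the breakdown above relies on a mapping from pool type to attach function. The source does not show how api.py looks entries up, so the sketch below assumes an isinstance scan; `ToyMHAPool`, `ToySWAPool`, and `attach_canary` are hypothetical names used only for illustration.

```python
from typing import Callable, Dict, Tuple


class ToyMHAPool:
    """Hypothetical stand-in for an MHA-only pool."""


class ToySWAPool:
    """Hypothetical stand-in for a pool with full and SWA sub-pools."""


def attach_mha(pool: object) -> Tuple[str, ...]:
    # One buffer group for a plain MHA pool.
    return ("full",)


def attach_swa(pool: object) -> Tuple[str, ...]:
    # Two buffer groups: one per sub-pool.
    return ("full", "swa")


# Registry keyed by pool type, mirroring the role of _POOL_ATTACHERS.
_POOL_ATTACHERS: Dict[type, Callable[[object], Tuple[str, ...]]] = {
    ToyMHAPool: attach_mha,
    ToySWAPool: attach_swa,
}


def attach_canary(pool: object) -> Tuple[str, ...]:
    """Pick the first registered attacher whose type matches the pool."""
    for pool_type, attacher in _POOL_ATTACHERS.items():
        if isinstance(pool, pool_type):
            return attacher(pool)
    raise TypeError(f"no canary adapter registered for {type(pool).__name__}")
```

Supporting a new pool type then amounts to writing one attach function and adding one dictionary entry, which is the property the summary credits to the adapter pattern.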

Key symbols

attach_swa attach_dsv4 _build_subpool_group make_swa_pool

Key source snippets

python/sglang/srt/kv_canary/pool_patcher/adapters/swa.py core-logic

Core adapter that attaches canary buffers to an SWA pool

from __future__ import annotations
from typing import Optional
import torch
from sglang.srt.kv_canary.buffer_group import CanaryBufferGroup, PoolKind
from sglang.srt.kv_canary.pool_patcher.buf_info_splice import patch_buf_info_method
from sglang.srt.kv_canary.pool_patcher.buffer_alloc import alloc_canary_buf


def attach_swa(
    *,
    pool: object,
    device: torch.device,
    kv_token_id_vs_position_offset: int,
) -> tuple[CanaryBufferGroup, ...]:
    # Build the canary buffer group for the full sub-pool (no swa_index_lut)
    full_group = _build_subpool_group(
        sub_pool=pool.full_kv_pool,
        kind=PoolKind.FULL,
        device=device,
        swa_lut=None,
        kv_token_id_vs_position_offset=kv_token_id_vs_position_offset,
    )
    # Build the canary buffer group for the SWA sub-pool (carries swa_index_lut)
    swa_group = _build_subpool_group(
        sub_pool=pool.swa_kv_pool,
        kind=PoolKind.SWA,
        device=device,
        swa_lut=pool.full_to_swa_index_mapping,
        kv_token_id_vs_position_offset=kv_token_id_vs_position_offset,
    )
    # Patch the full group's buf_info into get_contiguous_buf_infos
    patch_buf_info_method(
        pool,
        method_name="get_contiguous_buf_infos",
        group=full_group,
        has_v_half=True,
        page_size=pool.page_size,
    )
    # Patch the SWA group's buf_info into get_state_buf_infos
    patch_buf_info_method(
        pool,
        method_name="get_state_buf_infos",
        group=swa_group,
        has_v_half=True,
        page_size=pool.page_size,
    )
    return (full_group, swa_group)


def _build_subpool_group(
    *,
    sub_pool: object,
    kind: PoolKind,
    device: torch.device,
    swa_lut: Optional[torch.Tensor],
    kv_token_id_vs_position_offset: int,
) -> CanaryBufferGroup:
    # Take the slot count from the first dimension of the sub-pool's layer-0 k_buffer
    num_slots = int(sub_pool.k_buffer[0].shape[0])
    # Allocate four canary buffers: k_head, k_tail, v_head, v_tail
    k_head = alloc_canary_buf(num_slots=num_slots, device=device)
    k_tail = alloc_canary_buf(num_slots=num_slots, device=device)
    v_head = alloc_canary_buf(num_slots=num_slots, device=device)
    v_tail = alloc_canary_buf(num_slots=num_slots, device=device)
    return CanaryBufferGroup(
        kind=kind,
        k_head=k_head,
        k_tail=k_tail,
        v_head=v_head,
        v_tail=v_tail,
        swa_index_lut=swa_lut,
        kv_token_id_vs_position_offset=kv_token_id_vs_position_offset,
    )
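
The snippet above calls patch_buf_info_method, whose body is not shown in this summary. Judging only by the unit-test name (splicing the full group into the contiguous infos and the SWA group into the state infos), one plausible reading is that the pool's method is wrapped so its result also includes the canary buffers. The sketch below illustrates that idea under that assumption; `ToyPool` and `splice_extra_buf_infos` are hypothetical and are not the sglang implementation.

```python
from typing import List, Tuple

# (pointers, total byte lengths, per-item byte lengths)
BufInfos = Tuple[List[int], List[int], List[int]]


class ToyPool:
    """Hypothetical pool exposing one buf-info method."""

    def get_state_buf_infos(self) -> BufInfos:
        return [1000], [64], [16]


def splice_extra_buf_infos(pool: object, method_name: str, extra: BufInfos) -> None:
    """Wrap pool.<method_name> so its result also lists the extra buffers."""
    original = getattr(pool, method_name)  # bound method captured before patching

    def patched() -> BufInfos:
        ptrs, lens, item_lens = original()
        return (
            list(ptrs) + list(extra[0]),
            list(lens) + list(extra[1]),
            list(item_lens) + list(extra[2]),
        )

    # Instance-level patch: other instances of the class are unaffected.
    setattr(pool, method_name, patched)
```

For example, after `splice_extra_buf_infos(pool, "get_state_buf_infos", ([2000], [8], [8]))`, callers that enumerate the pool's buffers through that method would also see the extra entry without knowing it was added.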
python/sglang/srt/kv_canary/pool_patcher/adapters/dsv4.py core-logic

Core adapter that attaches canary buffers to a DeepSeek-V4 pool (covers only the SWA sub-pool)

from __future__ import annotations
import torch
from sglang.srt.kv_canary.buffer_group import CanaryBufferGroup, PoolKind
from sglang.srt.kv_canary.pool_patcher.buf_info_splice import patch_buf_info_method
from sglang.srt.kv_canary.pool_patcher.buffer_alloc import alloc_canary_buf


def attach_dsv4(
    *,
    pool: object,
    device: torch.device,
    kv_token_id_vs_position_offset: int,
) -> tuple[CanaryBufferGroup, ...]:
    """Attach canary buffers to a DeepSeekV4TokenToKVPool.
    TODO: only the swa_kv_pool sub-pool is wired; c4_kv_pool / c128_kv_pool /
    c4_indexer_kv_pool / compress state pools are left uncovered.
    """
    # Only the SWA sub-pool is handled for now
    sub_pool = pool.swa_kv_pool
    num_slots = int(sub_pool.size)
    # Allocate canary buffers for K only (no V buffers for now; has_v_half=False)
    k_head = alloc_canary_buf(num_slots=num_slots, device=device)
    k_tail = alloc_canary_buf(num_slots=num_slots, device=device)
    # Create a canary buffer group of kind SWA, without V buffers
    group = CanaryBufferGroup(
        kind=PoolKind.SWA,
        k_head=k_head,
        k_tail=k_tail,
        v_head=None,
        v_tail=None,
        swa_index_lut=pool.full_to_swa_index_mapping,
        kv_token_id_vs_position_offset=kv_token_id_vs_position_offset,
    )
    # Patch the group into get_state_buf_infos
    patch_buf_info_method(
        pool,
        method_name="get_state_buf_infos",
        group=group,
        has_v_half=False,
        page_size=sub_pool.page_size,
    )
    return (group,)
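
The reviewer's concern about pool.swa_kv_pool being None could be addressed with a small guard before the sub-pool is used. The helper below is a hypothetical sketch (`require_sub_pool` does not exist in the PR) showing one way to fail early with a clear message instead of an AttributeError deeper in the call.

```python
from types import SimpleNamespace


def require_sub_pool(pool: object, attr_name: str) -> object:
    """Return pool.<attr_name>, raising a descriptive error if it is absent or None."""
    sub_pool = getattr(pool, attr_name, None)
    if sub_pool is None:
        raise ValueError(
            f"{type(pool).__name__}.{attr_name} is missing or None; "
            "cannot attach canary buffers"
        )
    return sub_pool


# Usage with a stand-in pool object (SimpleNamespace is used only for illustration).
toy_pool = SimpleNamespace(swa_kv_pool=SimpleNamespace(size=8, page_size=1))
toy_sub_pool = require_sub_pool(toy_pool, "swa_kv_pool")
```

Whether to raise or to skip attachment silently is a design choice; raising is shown here because a canary that quietly covers nothing is hard to notice.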
python/sglang/test/kv_canary/fixtures.py test-coverage

Test fixtures providing FakeSWAPool and related mock classes for unit tests

@dataclass
class FakeSwaSubPool:
    k_buffer: List[torch.Tensor]
    v_buffer: List[torch.Tensor]


@dataclass
class FakeSWAPool:
    full_kv_pool: object
    swa_kv_pool: object
    full_to_swa_index_mapping: torch.Tensor
    page_size: int = 1

    def get_contiguous_buf_infos(self):
        # Return buffer infos for full_kv_pool
        return _kv_buf_infos(
            k_buffer=self.full_kv_pool.k_buffer,
            v_buffer=self.full_kv_pool.v_buffer,
            page_size=self.page_size,
        )

    def get_state_buf_infos(self):
        # Return buffer infos for swa_kv_pool
        return _kv_buf_infos(
            k_buffer=self.swa_kv_pool.k_buffer,
            v_buffer=self.swa_kv_pool.v_buffer,
            page_size=self.page_size,
        )


def _kv_buf_infos(*, k_buffer, v_buffer, page_size) -> tuple:
    # Shared helper: compute pointers, nbytes, and item_lens
    ptrs = [b.data_ptr() for b in k_buffer] + [b.data_ptr() for b in v_buffer]
    lens = [b.nbytes for b in k_buffer] + [b.nbytes for b in v_buffer]
    # Note: use b.shape[1:].numel() rather than b[0] so that empty tensors are supported
    item_lens = [b.shape[1:].numel() * b.element_size() * page_size for b in k_buffer] + \
                [b.shape[1:].numel() * b.element_size() * page_size for b in v_buffer]
    return ptrs, lens, item_lens


def make_swa_pool(
    device: torch.device = DEFAULT_DEVICE,
    *,
    full_slots: int = 16,
    swa_slots: int = 8,
    dim: int = 8,
    layer_num: int = 1,
) -> FakeSWAPool:
    # Create mock data for the full and SWA sub-pools
    full = FakeSwaSubPool(
        k_buffer=[torch.zeros(full_slots, dim, dtype=torch.float16, device=device) for _ in range(layer_num)],
        v_buffer=[torch.zeros(full_slots, dim, dtype=torch.float16, device=device) for _ in range(layer_num)],
    )
    swa = FakeSwaSubPool(
        k_buffer=[torch.zeros(swa_slots, dim, dtype=torch.float16, device=device) for _ in range(layer_num)],
        v_buffer=[torch.zeros(swa_slots, dim, dtype=torch.float16, device=device) for _ in range(layer_num)],
    )
    # LUT: the first swa_slots entries map to SWA slots; the rest are -1
    lut = torch.full((full_slots + 1,), -1, dtype=torch.int64, device=device)
    lut[:swa_slots] = torch.arange(swa_slots, dtype=torch.int64, device=device)
    return FakeSWAPool(full_kv_pool=full, swa_kv_pool=swa, full_to_swa_index_mapping=lut)
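
The lookup table built at the end of make_swa_pool can be illustrated without torch. The sketch below mirrors the fixture's layout (full_slots + 1 entries, the first swa_slots mapped to themselves, the rest set to -1) using plain lists; `build_full_to_swa_lut` and `translate` are hypothetical helper names, and the assumption that -1 means "no SWA slot" is inferred from the fixture comment.

```python
from typing import List


def build_full_to_swa_lut(full_slots: int, swa_slots: int) -> List[int]:
    """Build a LUT of length full_slots + 1 mapping full-pool indices to SWA indices."""
    if swa_slots > full_slots + 1:
        raise ValueError("swa_slots cannot exceed the LUT length")
    lut = [-1] * (full_slots + 1)
    for i in range(swa_slots):
        lut[i] = i  # identity mapping for the first swa_slots entries
    return lut


def translate(lut: List[int], full_indices: List[int]) -> List[int]:
    """Map full-pool slot indices to SWA slot indices (-1 where unmapped)."""
    return [lut[i] for i in full_indices]


lut = build_full_to_swa_lut(full_slots=4, swa_slots=2)
mapped = translate(lut, [0, 1, 3])
```

With these toy sizes, full slots 0 and 1 have SWA counterparts while slot 3 does not, which is the kind of case a test using the fixture would want to exercise.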

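Comment highlights

Top-level import of DeepSeekV4TokenToKVPool may cause startup failure · Safety

gemini-code-assist[bot] suggested wrapping the import in try-except to prevent a crash in environments without DeepSeek support.

Conclusion: the PR was merged without a corresponding change in the diff. · unresolved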

Model path typo · Correctness

gemini-code-assist[bot] pointed out that `gemma-4-E2B-it` in `mode_config.py` should be `gemma-2-2b-it`.

Conclusion: not changed. · unresolved

Missing None check on sub-pools · Correctness

gemini-code-assist[bot] suggested adding defensive checks in dsv4.py and swa.py for sub_pool being None.

Conclusion: not changed. · unresolved

Empty-tensor indexing risk · Testing

gemini-code-assist[bot] suggested using `b.shape[1:].numel()` in fixtures.py to avoid an IndexError on empty tensors.

Conclusion: not changed. · unresolved
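
The difference between the two approaches in this comment comes down to whether a row is read or only the shape is inspected. A torch-free sketch of the shape-based computation follows; `item_len_bytes` is a hypothetical helper, with `math.prod` over the trailing dimensions standing in for `b.shape[1:].numel()`.

```python
import math
from typing import Tuple


def item_len_bytes(shape: Tuple[int, ...], element_size: int, page_size: int) -> int:
    """Bytes per page, computed from the trailing dimensions only.

    No row is ever indexed, so a buffer with zero slots is handled the same
    way as a populated one.
    """
    per_slot_elements = math.prod(shape[1:])  # product over an empty tuple is 1
    return per_slot_elements * element_size * page_size


populated = item_len_bytes((16, 8), element_size=2, page_size=1)
empty = item_len_bytes((0, 8), element_size=2, page_size=1)
```

Both calls give the same per-item size, whereas indexing row 0 of a buffer that has no rows raises IndexError.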

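Risks and impact

  1. The top-level import of DeepSeekV4TokenToKVPool in api.py may cause startup failure on platforms without DeepSeek support.
  2. dsv4.py and swa.py lack defensive checks for sub-pool attributes (swa_kv_pool, k_buffer) being None or empty, which may raise AttributeError or IndexError.
  3. The model path in the test configuration is wrong (gemma-4-E2B-it does not exist), which would cause the end-to-end tests to fail.
  4. The test fixtures use b[0], which raises IndexError when full_slots or swa_slots is 0.

For users: anyone running an SWA or DeepSeek-V4 model with KV-canary enabled gets KV-cache monitoring automatically, with no extra configuration. For the system: adapters are registered only at startup, so there is no runtime performance impact. For the team: the adapter code must be maintained and kept in sync with the underlying pool interfaces.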

Core path change · Environment compatibility · Missing defensive checks · Test configuration error

Related Issues

No related Issue identified

