Prhub

#24932 [PD] Refactor hybrid state transfer

原始 PR 作者 ispobock 合并时间 2026-05-12 13:16 文件变更 13 提交数 6 评论 8 代码增减 +594 / -377

执行摘要

重构 PD 状态传输以支持多状态类型扩展

引入多状态类型列表以增强扩展性,使后端可以迭代 state_types: List[StateType] 配合 List[List[X]] 字段,替代原来的单 state_type 和平铺 List[X]。单状态模型行为和之前一致,新增状态类型只需在 setup 侧追加一项。

值得精读,特别是对 PD 分布式推理实现感兴趣的开发者。该 PR 通过引入枚举和列表循环,巧妙地消除了多状态转移中的大量 if-elif 判断,使添加新状态变得简单。同时,review 中的讨论澄清了去重守卫的设计动机,帮助理解混合模型状态注册的潜在陷阱。建议后续跟进 get_mamba_state_buf_infos 的泛化改造。

讨论亮点

Review 中主要讨论了以下几点:

  • 类型风格一致:ShangmingCai 建议统一使用 Listlist 类型标注,ispobock 接受了建议。
  • 去重守卫逻辑:ShangmingCai 对 StateType.MAMBA not in 提出疑问,ispobock 解释这是为了防止重复追加(当 HybridLinearKVPool 分支已添加 MAMBA 时,后续通过 req_to_token_pool 的 fallback 不再重复添加)。
  • 保留注释建议:ShangmingCai 建议保留说明 DeepSeekV4 继承层次关系的注释,方便后续优化。
  • 进一步重构建议:ShangmingCai 提到应将 get_mamba_state_buf_infos 统一为 get_state_buf_infos,ispobock 未在本次 PR 中处理。

实现拆解

  1. 基础类型定义:在 disaggregation/base/conn.py 添加 StateType 枚举(NONEMAMBASWANSA),替换原有的字符串 state_type 字段,并在 KVArgs 中将相关字段改为 List[List[...]] 以支持多状态组件。
  2. 序列化工具:在 disaggregation/common/utils.py 新增 pack_list_of_buffers / unpack_list_of_bufferspack_int_lists / unpack_int_lists 四个函数,用于将嵌套列表封包成紧凑的 bytes 序列,供 ZMQ wire 传输。
  3. 状态注册与遍历:在 disaggregation/utils.py 新增 append_state_component 辅助函数,重写 setup_state_kv_args,使其根据 pool 类型顺序追加组件,而非一次性赋值。原先的单分支 if-elif 被拆分为内嵌的 payload 函数(_mamba_payload_swa_payload_nsa_payload),在发/收端由循环驱动调用。
  4. 后端适配:修改 NIXL(conn.py)和 Mooncake(conn.py)的 wire 数据结构,将 dst_state_indicesdst_state_data_ptrs 等从 List[int] 改为 List[List[int]],并使用新的 pack/unpack 处理;Ascend 和 Mori 连接层做相应字段映射。
  5. 测试覆盖:新增 test/registered/unit/disaggregation/test_disaggregation_wire.py,测试 pack_int_lists / unpack_int_lists 的往返正确性,包括空列表、嵌套列表、ndarray 输入等边界情况。
文件 模块 状态 重要度
python/sglang/srt/disaggregation/common/utils.py 序列化工具 modified 8.03
python/sglang/srt/disaggregation/prefill.py 预填充 modified 7.99
python/sglang/srt/disaggregation/decode.py 解码 modified 7.94
python/sglang/srt/disaggregation/nixl/conn.py NIXL 后端 modified 7.72
python/sglang/srt/disaggregation/mooncake/conn.py Mooncake 后端 modified 7.72
python/sglang/srt/disaggregation/utils.py 状态管理 modified 7.67
test/registered/unit/disaggregation/test_disaggregation_wire.py 解聚测试 added 7.36
python/sglang/srt/disaggregation/base/conn.py 基础接口 modified 6.85
python/sglang/srt/mem_cache/memory_pool.py 缓存池 modified 6.06
python/sglang/srt/disaggregation/mori/conn.py Mori 后端 modified 6.05
python/sglang/srt/disaggregation/common/conn.py 公共连接层 modified 5.95
python/sglang/srt/disaggregation/ascend/conn.py Ascend 后端 modified 5.52

关键符号

pack_list_of_buffers unpack_list_of_buffers pack_int_lists unpack_int_lists append_state_component setup_state_kv_args _mamba_payload _swa_payload _nsa_payload get_state_buf_infos get_state_dim_per_tensor StateType

关键源码片段

python/sglang/srt/disaggregation/prefill.py core-logic

预填充侧核心逻辑,将状态打包拆分为内嵌函数(_mamba_payload/_swa_payload/_nsa_payload),由 state_types 循环驱动。

# prefill.py — send_kv_chunk 方法内的状态负载内嵌函数
seq_len = len(req.fill_ids)def _mamba_payload():
    """从 req_to_token_pool 获取 Mamba 状态索引(单值列表)。"""
    return [self.req_to_token_pool
            .req_index_to_mamba_index_mapping[req.req_pool_idx]
            .cpu().numpy()]def _swa_payload():
    """计算滑动窗口 KV 索引并转换到 SWA pool,返回 page 索引。"""
    window_start = max(0, seq_len - window_size)
    window_start = (window_start // page_size) * page_size
    window_kv_indices_full = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, window_start:seq_len]
    window_kv_indices_swa = self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
        window_kv_indices_full)
    return kv_to_page_indices(window_kv_indices_swa.cpu().numpy(), page_size)def _nsa_payload():
    """取完整前缀 KV 索引并转换为 page 索引。"""
    kv_indices_full = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, :seq_len]
    return kv_to_page_indices(kv_indices_full.cpu().numpy(), page_size)
python/sglang/srt/disaggregation/utils.py core-logic

新增 append_state_component,重构 setup_state_kv_args 为循环追加组件,是状态注册的核心。

def setup_state_kv_args(
    kv_args: KVArgs,
    token_to_kv_pool,
    draft_token_to_kv_pool=None,
    req_to_token_pool=None,
) -> None:
    """Populate kv_args state-buffer fields from the given pool.
    Shared by prefill and decode bootstrap paths."""
    from sglang.srt.disaggregation.base.conn import StateType
    from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool
    from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool
​
    kv_args.state_types = []
    kv_args.state_data_ptrs = []
    kv_args.state_data_lens = []
    kv_args.state_item_lens = []
    kv_args.state_dim_per_tensor = []
​
    target = token_to_kv_pool
    # 主 KV 池的状态
    if hasattr(target, "get_state_buf_infos"):
        state_data_ptrs, state_data_lens, state_item_lens = target.get_state_buf_infos()
        if isinstance(target, BaseSWAKVPool):
            append_state_component(kv_args, StateType.SWA,
                                   state_data_ptrs, state_data_lens, state_item_lens)
        elif isinstance(target, HybridLinearKVPool):
            append_state_component(kv_args, StateType.MAMBA,
                                   state_data_ptrs, state_data_lens, state_item_lens,
                                   dim_per_tensor=target.get_state_dim_per_tensor()
                                   if hasattr(target, "get_state_dim_per_tensor") else None)
        elif isinstance(target, NSATokenToKVPool):
            append_state_component(kv_args, StateType.NSA,
                                   state_data_ptrs, state_data_lens, state_item_lens)
​
    # draft KV 池(若有)同样处理
    if draft_token_to_kv_pool is not None:
        draft = draft_token_to_kv_pool
        if hasattr(draft, "get_state_buf_infos"):
            # 类似分支,省略
            pass
​
    # 如果同时有 req_to_token_pool 上的 Mamba 状态且未在池中注册
    if req_to_token_pool is not None and StateType.MAMBA not in kv_args.state_types:
        data_ptrs, data_lens, item_lens = req_to_token_pool.get_mamba_state_buf_infos()
        append_state_component(kv_args, StateType.MAMBA,
                               data_ptrs, data_lens, item_lens)

评论区精华

统一 List 和 list 类型标注 style

ShangmingCai 建议在 decode.py 和 prefill.py 中统一使用 `List` 或 `list`,保持类型标注一致。

结论:ispobock 接受了建议并已在 commit 中修正。 · 已解决

去重守卫逻辑解释 设计

ShangmingCai 对 `StateType.MAMBA not in` 提出疑问,ispobock 解释:当 `HybridLinearKVPool` 分支已添加 MAMBA 时,后续通过 `req_to_token_pool` 的 fallback 通过 `not in` 避免重复追加。

结论:理解并接受,无代码改动。 · 已解决

保留 DeepSeekV4 注释 documentation

ShangmingCai 建议保留说明 DeepSeekV4TokenToKVPool 继承 BaseSWAKVPool 的注释,方便后续优化。

结论:ispobock 接受了建议。 · 已解决

重构 get_mamba_state_buf_infos 设计

ShangmingCai 建议将 `get_mamba_state_buf_infos` 重构为 `get_state_buf_infos`,统一接口。

结论:ispobock 未在本次 PR 中处理,可能后续跟进。 · unresolved

风险与影响

  1. 序列化兼容性:wire 格式从扁平的 int 列表变为嵌套的 int 列表,所有 PD 后端必须同步更新,否则解析错误。
  2. 状态顺序一致性:Prefill 和 Decode 端的 state_types 列表追加顺序必须一致;目前的实现依赖相同的 pool 类型判断逻辑,若模型配置组合特殊可能出现错位。
  3. 边界条件风险:新增的 pack/unpack 函数未在实际大吞吐场景下验证,空列表和超大列表的边缘情况可能暴漏性能或正确性问题。
  4. 零长度状态缓冲区:NIXL 后端增加 guard zero-length state buffers 提交,说明已有零长度保护,但仍需关注跨 TP 的 state 切片计算。

影响范围:所有使用 PD 解聚功能的场景,特别是混合注意力模型(如 DeepSeek-V4、Qwen3-Next 等)。对于纯 Attention 模型(MHA/GQA)不受影响,因为 state_types 长度为零。
影响程度:中等。虽然重构了核心传输路径,但功能和外部接口未变,仅内部数据表示改变。合并后需要各后端同步更新,同时测试验证(PR body 附件了 gsm8k 准确率保持数据)。
团队影响:增强了解聚层的可扩展性,后续添加新状态类型(如线性注意力)只需在枚举和新加入分支即可,降低维护成本。

序列化兼容 状态顺序依赖 零长度保护 跨后端同步

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论