Prhub

#22536 [Disagg][NIXL] Add staging buffer support for heterogeneous TP KV transfer

原始 PR 作者 YAMY1234 合并时间 2026-05-13 19:54 文件变更 4 提交数 25 评论 20 代码增减 +820 / -146

执行摘要

NIXL 后端新增 staging buffer 支持异构 TP 批量 KV 传输

NIXL disaggregated serving currently requires prefill and decode to use the same TP layout. When prefill uses TP4 and decode uses DEP4 (DP4+TP4+EP4), each prefill rank's KV cache must be split and sent to multiple decode ranks. Without staging buffers, the prefill side must issue prefill_tp × decode_tp separate RDMA transfers per chunk, saturating the RDMA descriptor table and adding significant per-transfer overhead. The staging buffer approach (already implemented for mooncake in #19890) consolidates KV heads into a contiguous staging region on prefill, issues a single bulk RDMA transfer per rank pair, and lets the decode side scatter from the staging buffer into the final KV cache pages asynchronously.

此 PR 重构了 NIXL 的 KV 传输路径,引入 staging buffer 后性能提升显著。代码抽象值得学习,尤其是 register_fn 回调模式。建议 NIXL 用户升级并启用 staging buffer。mooncake 用户无需改动,但可从统一抽象中受益。

讨论亮点

本 PR review 中重点讨论了以下内容:

  • 重复导入:ShangmingCai 指出 KVArgsRegisterInfo.from_zmq 中重复导入 StagingRegisterInfo,YAMY1234 已移除冗余导入。
  • 字段位置注释:要求为 staging 字段添加注释说明其位置须保持在最后以确保 zmq 解析的向后兼容性。
  • 预取字典清理:ShangmingCai 担忧长时间运行后预取字典膨胀,YAMY1234 确认已在 transfer 状态变为 Success 时清理相关条目。
  • 通知标签解析:gemini-code-assist 建议使用 split("_", 8) 限制分割次数以避免 agent_name 中的下划线导致错误,作者已应用此修复。
  • staging 降级重试:gemini-code-assist 指出当 staging 分配未就绪时立即降级为分片传输,缺乏重试机制,可能影响性能。这一建议未被完全采纳,作为未来改进方向。

实现拆解

  1. 共享 staging 生命周期提取:将 mooncake 连接器中的 handle_watermark_msghandle_staging_rspinit_staging_buffersinit_staging_allocator 抽取到 common/staging_handler.py,并设计 register_fn 回调使后端的注册行为可通过 lambda 注入,实现 NIXL 和 mooncake 的统一。
  2. NIXL staging 上下文初始化:在 nixl/conn.pyNixlKVSenderNixlKVReceiver 中分别初始化 prefill 和 decode 侧的 staging 上下文(_init_staging_prefill_ctx/_init_staging_decode_ctx),为每个 worker 线程创建独立的 staging buffer,避免 cross-worker 竞争。
  3. NIXL 特定 staging 传输:新增 send_kvcache_staged 方法,收集所有 KV 层到一个连续暂存 buffer 后发起单次 RDMA 写;进度通知通过 RDMA tag 嵌入 chunk 索引、page offset 等元数据,decode 端主线程 poll 路径解析 tag 并调用 handle_chunk_arrived 触发 scatter。
  4. mooncake 后端适配mooncake/conn.py 中原有的 WATERMARK 和 STAGING_RSP 处理逻辑替换为调用共享函数,降低重复代码约 57 行。
  5. 配置验证:在 server_args.py 中将允许启用 staging buffer 的后端从仅 mooncake 扩展为 mooncake 或 nixl。
文件 模块 状态 重要度
python/sglang/srt/disaggregation/nixl/conn.py NIXL 连接器 modified 8.84
python/sglang/srt/disaggregation/common/staging_handler.py 暂存处理 modified 8.48
python/sglang/srt/disaggregation/mooncake/conn.py Mooncake 连接器 modified 6.78
python/sglang/srt/server_args.py 服务配置 modified 4.5

关键符号

_init_staging_prefill_ctx _init_staging_decode_ctx _init_staging_buffers _init_staging_allocator _register_staging_memory set_kv_buffer_tensors register_staging_room_bootstrap _is_watermark_ready handle_chunk_arrived handle_watermark_msg handle_staging_rsp init_staging_buffers init_staging_allocator send_kvcache_staged _handle_stg_notification

关键源码片段

python/sglang/srt/disaggregation/nixl/conn.py dependency-wiring

核心实现文件,包含 NIXL staging 缓冲区初始化、RDMA 通知处理和 staging 传输函数。

# file: python/sglang/srt/disaggregation/nixl/conn.py# 在 KVArgsRegisterInfo 末尾添加 staging 字段,保持字段顺序稳定
@dataclasses.dataclass
class KVArgsRegisterInfo:
    room: str
    endpoint: str
    # ... 其他字段 ...
    dst_kv_item_len: int
    dst_state_item_lens: List[List[int]] = dataclasses.field(default_factory=list)
    dst_state_dim_per_tensor: List[List[int]] = dataclasses.field(default_factory=list)
    # 新增 staging 字段:可选,由 ZMQ 帧的可变长尾部解析,
    # 必须保持在最后以保证位置构造稳定
    staging: Optional["StagingRegisterInfo"] = None
​
    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        # ... 解析前 13 个字段 ...
        # 从帧的索引 14 开始解析 staging 信息(如果存在)
        return cls(
            # ... 基础字段赋值 ...
            staging=StagingRegisterInfo.from_zmq_fields(msg, 14),
        )# 在 NixlKVSender.__init__ 中初始化 staging 预取上下文
if self.disaggregation_mode == DisaggregationMode.PREFILL:
    if self.enable_staging:
        self._init_staging_prefill_ctx()
        self._init_staging_buffers(len(self.transfer_queues))
    for i, queue in enumerate(self.transfer_queues):
        # 每个 worker 获得独立的 staging buffer,避免竞争
        staging_buffer = (
            self._staging_ctx.buffers[i]
            if self.enable_staging and self._staging_ctx.buffers
            else None
        )
        threading.Thread(
            target=self.transfer_worker,
            args=(queue, staging_buffer),
            daemon=True,
        ).start()
python/sglang/srt/disaggregation/common/staging_handler.py entrypoint

共享 staging 生命周期管理,提取后通用函数供 NIXL 和 mooncake 使用。

# file: python/sglang/srt/disaggregation/common/staging_handler.py# 共享的 chunk 到达处理:NIXL 通过 RDMA 通知调用此方法,
# mooncake 通过 ZMQ CHUNK_READY 调用
def handle_chunk_arrived(
    self,
    room: int,
    chunk_idx: int,
    page_start: int,
    num_pages: int,
    writer_id: str,
    chunk_writer_counts: dict,
) -> bool:
    """累加 writer 到达数,当所有 writer 报告完毕时提交 scatter"""
    chunk_writer_counts[room][chunk_idx].append(
        (page_start, num_pages, writer_id)
    )
    decode_req = self._room_to_decode_req.get(room)
    if decode_req is None:
        logger.warning(
            "Staging chunk arrived for unregistered room=%s chunk=%d, skipping",
            room, chunk_idx,
        )
        return False
    writers_arrived = len(chunk_writer_counts[room][chunk_idx])
    num_writers = self.num_writers_for(decode_req)
    if writers_arrived >= num_writers:
        self.submit_chunk_scatter(room, chunk_idx, page_start, num_pages)
        del chunk_writer_counts[room][chunk_idx]
        return True
    return False# 共享的 watermark 消息处理:更新远程 watermark 并通知等待线程
def handle_watermark_msg(staging_ctx, msg_parts) -> None:
    wm_round = int(msg_parts[1].decode("ascii"))
    wm_tail = int(msg_parts[2].decode("ascii"))
    wm_session = msg_parts[3].decode("ascii") if len(msg_parts) > 3 else ""
    with staging_ctx.watermark_cv:
        prev = staging_ctx.remote_watermarks.get(wm_session, (0, 0))
        if (wm_round, wm_tail) > prev:
            staging_ctx.remote_watermarks[wm_session] = (wm_round, wm_tail)
            staging_ctx.watermark_cv.notify_all()

评论区精华

KVArgsRegisterInfo 中 staging 字段重复导入 style

ShangmingCai 指出 staging_handler.py 的 StagingRegisterInfo 在 from_zmq 方法中被重复导入,建议移除冗余导入。

结论:YAMY1234 已修复,移除了重复的 import。 · 已解决

staging 字段应加注释说明位置 documentation

ShangmingCai 请求为 KVArgsRegisterInfo 的 staging 字段添加注释,说明其位置必须保持最后以确保 ZMQ 解析兼容。

结论:作者已添加 inline 注释。 · 已解决

预取字典可能长期运行后膨胀 性能

ShangmingCai 提出 `prefetch_requested` 字典可能无限增长,需要清理机制。

结论:YAMY1234 回复在 transfer_worker 中传输完成后清理相关条目。 · 已解决

RDMA 通知标签解析鲁棒性 正确性

gemini-code-assist 建议使用 split("_", 8) 限制分割次数,避免 agent_name 含下划线时解析错误。

结论:作者已应用该建议。 · 已解决

staging fallback 机制缺失重试 设计

gemini-code-assist 指出当 staging 分配未就绪时立即降级为分片传输,建议像 mooncake 一样重试。

结论:未在本次 PR 中解决,可能作为未来改进。 · unresolved

风险与影响

1) NIXL 的 RDMA 通知处理在主线程上进行,可能增加调度器延迟(NIXL C++ 线程安全约束)。
2) staging buffer 占用额外 GPU 显存,需合理配置 SGLANG_DISAGG_STAGING_BUFFER 和环境变量。
3) 水印协调逻辑的并发正确性依赖 CPython GIL,未来迁移到自由线程可能需修改。
4) 回退到分片传输的逻辑在 staging 分配未就绪时立即降级,可考虑重试机制(评论已指出)。

影响 NIXL 后端用户,特别是使用异构 TP(如 TP4 prefill + DEP4 decode)的部署。在高并发下显著提升吞吐量和降低首词延迟。mooncake 用户受益于代码精简重构。对系统无破坏性变化,默认不启用(需设置环境变量 SGLANG_DISAGG_STAGING_BUFFER=1)。

RDMA 通知主线程处理 staging 分配未就绪降级 预取字典内存增长

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论