执行摘要
- 一句话:将 Mamba 状态操作迁移到 forward stream 消除调度竞争
- 推荐动作:该 PR 值得所有关注并发调度和 Mamba 模型的开发者精读,其“捐赠模式”和“延迟操作到 forward stream”是处理调度器与前向流之间竞争的有效模式。建议合并后关注 HiCache 兼容性修复和 review 中提到的 GPU→CPU 同步优化。
功能与动机
关联 Issue #24221 描述了在 overlap scheduler 中,Mamba radix cache 可能在前向 pass 尚未完成写入时快照 Mamba 临时状态,导致 radix cache 条目损坏。具体路径是 chunked prefill 的 stash 操作没有等待 copy_done 同步。本 PR 通过将所有 Mamba 状态操作移到 forward stream,从根本上消除竞争窗口。
实现拆解
- 将
MambaPool.alloc 中的立即清零操作拆分为独立的 clear_slots 方法,原 alloc 不再执行清零,让调度器可以分配插槽而不触发 GPU 操作。
- 在
cache_unfinished_req 中,将来自 req 的 Mamba 状态索引“捐赠”给 radix cache(即直接转移所有权),然后为请求分配新的空白插槽;替换原来的 fork_from 复制模式,避免在读缓存时与 forward stream 产生数据竞争。
- 在 prefix match 的 COW 路径中,不在 scheduler stream 上立即复制,而是将源索引记录到
req.mamba_cow_src_index 和 req.mamba_needs_clear,实际复制/清零延迟到 forward stream 的 init_forward_metadata 阶段。
- 在
hybrid_linear_attn_backend(以及 Mamba2 后端)的 init_forward_metadata 中调用新方法 _execute_deferred_mamba_cow_and_clear,执行之前收集到的清除和复制操作;同时利用 is_draft_worker 跳过推测解码的 draft 阶段,防止重复执行。
- 提取公共辅助函数
set_mamba_track_indices_from_reqs 以减少 eagle_info_v2 等位置对 Mamba 追踪索引的重复计算;统一 copy_from 接口支持批量索引;删除不再使用的 fork_from 方法。
关键文件:
python/sglang/srt/managers/schedule_batch.py(模块 调度器;类别 source;类型 core-logic;符号 set_mamba_track_indices_from_reqs, _collect_deferred_mamba_cow_and_clear, prepare_for_split_prefill): 核心调度批次类,添加延迟 COW/clear 字段,新增收集和提取辅助函数,是 PR 主要逻辑入口
python/sglang/srt/mem_cache/memory_pool.py(模块 缓存池;类别 source;类型 core-logic;符号 clear_slots, copy_from, fork_from): 将 alloc 中的清零独立成 clear_slots,删除 fork_from,修改 copy_from 支持批量操作,为延迟执行奠定基础
python/sglang/srt/mem_cache/mamba_radix_cache.py(模块 缓存树;类别 source;类型 core-logic;符号 _alloc_mamba_slot): cache_unfinished_req 采用捐赠模式替代 fork 复制,避免与 forward stream 竞争
python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py(模块 缓存层;类别 source;类型 core-logic;符号 _alloc_mamba_slot): 适配统一缓存的 Mamba 组件,采用捐赠模式并添加 _alloc_mamba_slot
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py(模块 注意力后端;类别 source;类型 core-logic;符号 _execute_deferred_mamba_cow_and_clear): 添加 _execute_deferred_mamba_cow_and_clear 在 forward stream 上执行延迟操作
python/sglang/srt/model_executor/forward_batch_info.py(模块 前向元数据;类别 source;类型 data-contract): 为 ForwardBatch 添加延迟操作字段,连接调度和 forward stream
python/sglang/srt/speculative/eagle_info_v2.py(模块 推测解码;类别 source;类型 dependency-wiring): 使用公共辅助函数 set_mamba_track_indices_from_reqs,消除重复代码
python/sglang/srt/mem_cache/hi_mamba_radix_cache.py(模块 缓存树;类别 source;类型 core-logic): 适配 HiCache 的捐赠模式变更
关键符号:set_mamba_track_indices_from_reqs, _collect_deferred_mamba_cow_and_clear, clear_slots, copy_from, _alloc_mamba_slot, _execute_deferred_mamba_cow_and_clear, prepare_for_split_prefill
关键源码片段
python/sglang/srt/managers/schedule_batch.py
核心调度批次类,添加延迟 COW/clear 字段,新增收集和提取辅助函数,是 PR 主要逻辑入口
def set_mamba_track_indices_from_reqs(batch):
"""从请求对象构建 mamba_track_indices(权威来源)。
避免之前在 eagle_info 中的重复实现,并确保索引来源一致。
"""
req_to_token_pool = batch.req_to_token_pool
# 获取所有请求的 ping-pong 映射缓冲区,形状 (bs, ping_pong_size)
all_buffers = req_to_token_pool.req_index_to_mamba_ping_pong_track_buffer_mapping[
batch.req_pool_indices
] # shape: (bs, ping_pong_size), int64, on device
# 从每个请求的 mamba_next_track_idx 构建列索引
idx = (
torch.tensor(
[req.mamba_next_track_idx for req in batch.reqs],
dtype=torch.int64,
pin_memory=True,
)
.unsqueeze(1) # shape: (bs, 1)
.to(device=all_buffers.device, non_blocking=True)
)
# 通过 gather 取出正确的跟踪索引,结果形状 (bs,)
batch.mamba_track_indices = (
torch.gather(all_buffers, 1, idx).squeeze(1).to(torch.int64)
)
python/sglang/srt/mem_cache/memory_pool.py
将 alloc 中的清零独立成 clear_slots,删除 fork_from,修改 copy_from 支持批量操作,为延迟执行奠定基础
def clear_slots(self, indices: torch.Tensor):
"""Zero out mamba state at the given pool indices. 必须在 forward stream 上执行。"""
need_size = len(indices)
# 清除 conv state
for i in range(len(self.mamba_cache.conv)):
t = self.mamba_cache.conv[i]
# 扩展零张量以适应多个插槽,避免 CPU-GPU 同步
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, indices] = z
# 清除 temporal state
t = self.mamba_cache.temporal
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, indices] = z
def copy_from(self, src_indices: torch.Tensor, dst_indices: torch.Tensor):
"""从源索引复制 mamba 状态到目标索引。在 forward stream 上执行。"""
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, dst_indices] = self.mamba_cache.conv[i][
:, src_indices
]
self.mamba_cache.temporal[:, dst_indices] = self.mamba_cache.temporal[
:, src_indices
]
评论区精华
风险与影响
关联脉络
参与讨论