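Executive Summary
- In one sentence: Fix the DSA/SWA state-page transfer mismatch in PD disaggregation.
- Recommended action: Read this PR closely to see how boundary guards and a length clamp fix a subtle distributed-transfer bug. The defensive design of group_concurrent_contiguous is worth borrowing. Teams running DSA models in a disaggregated deployment should merge it promptly.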
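Feature and Motivation
On DSA (NSA) models deployed with PD disaggregation and the mooncake transfer backend, the prefill side sends the last chunk after fill_ids already contains the sampled token. Its state-page list can therefore be one page longer than the list the decode side registered, which triggers a NumPy broadcast shape-mismatch crash in group_concurrent_contiguous. From the PR body: 'On the last chunk, the prefill side enumerates its DSA/SWA state-page list over seq_len = len(req.fill_ids). By send time fill_ids already includes the token sampled during prefill... But the decode side registers its destination state pages over len(origin_input_ids)'.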
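To make the off-by-one concrete, here is a minimal sketch of the page arithmetic. The helper `num_pages` and the values used (64-token pages, a 128-token prompt) are illustrative assumptions, not taken from the PR; the real page enumeration in the codebase goes through `kv_to_page_indices`.

```python
def num_pages(seq_len: int, page_size: int) -> int:
    """Number of pages needed to hold seq_len tokens (ceiling division)."""
    return (seq_len + page_size - 1) // page_size


page_size = 64             # assumed value, for illustration only
origin_len = 128           # prompt length, which the decode side registers over
fill_len = origin_len + 1  # prefill side: prompt plus the sampled token

decode_pages = num_pages(origin_len, page_size)
prefill_pages_before_fix = num_pages(fill_len, page_size)
prefill_pages_after_fix = num_pages(min(fill_len, origin_len), page_size)

# The sampled token spills into a new page only when the prompt length is an
# exact multiple of the page size, as it is here.
print(decode_pages, prefill_pages_before_fix, prefill_pages_after_fix)
```

Under these assumptions the prefill side enumerates one more page than the decode side before the fix, and the same number after clamping to the prompt length.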
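Implementation Breakdown
Steps
- Harden group_concurrent_contiguous (python/sglang/srt/disaggregation/common/utils.py):
  - If either src_indices or dst_indices is empty, return empty lists immediately, rather than returning early only when src is empty.
  - If both are non-empty but differ in length, raise a ValueError with an explicit message.
  - This extends the previous src-only check and avoids a NumPy broadcast error when only dst is empty.
- Fix the prefill state-page sequence length (python/sglang/srt/disaggregation/prefill.py):
  - In send_kv_chunk, compute seq_len as min(len(req.fill_ids), len(req.origin_input_ids)) instead of len(req.fill_ids).
  - This keeps the state-page enumeration range consistent with the main-pool transfer range (already clamped by end_idx), so the sampled token no longer adds an extra page.
- Add unit tests (test/registered/unit/disaggregation/test_disaggregation_wire.py):
  - New TestGroupConcurrentContiguous class with 6 tests:
    - test_single_contiguous_group: normal contiguous grouping
    - test_splits_on_discontiguous_indices: split on non-contiguous indices
    - test_both_empty: both sides empty
    - test_empty_src_nonempty_dst: empty source, non-empty destination (returns empty)
    - test_nonempty_src_empty_dst: non-empty source, empty destination (regression test; returns empty instead of crashing)
    - test_mismatched_nonempty_lengths_raise: unequal lengths raise ValueError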
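The reason for both guards can be seen in a small sketch of what an unguarded break-point computation does. `unguarded_breaks` is a hypothetical helper written for this illustration; it mirrors the np.diff expression used in group_concurrent_contiguous but is not the project's code.

```python
import numpy as np


def unguarded_breaks(src, dst):
    """Break-point computation with no length checks (illustration only)."""
    return np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1


src = np.array([1, 2, 3], dtype=np.int32)

# Empty dst: the two diff arrays have shapes (2,) and (0,), which NumPy
# cannot broadcast together, so the OR raises ValueError.
try:
    unguarded_breaks(src, np.array([], dtype=np.int32))
    empty_dst_outcome = "no error"
except ValueError:
    empty_dst_outcome = "broadcast error"

# dst one element shorter than src: the diff shapes are (2,) and (1,).
# These DO broadcast, so nothing is raised even though the inputs are not
# pairwise aligned, and the resulting groups have different lengths.
short_dst = np.array([5, 6], dtype=np.int32)
silent_breaks = unguarded_breaks(src, short_dst)
src_groups = [g.tolist() for g in np.split(src, silent_breaks)]
dst_groups = [g.tolist() for g in np.split(short_dst, silent_breaks)]
```

The second case is why an explicit length check is useful in addition to the empty-side guard: some mismatches do not crash at all and would otherwise pass through unnoticed.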
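Key files:
python/sglang/srt/disaggregation/common/utils.py (module: disaggregated transfer; category: source; type: core-logic; symbol: group_concurrent_contiguous): Core defensive logic. Fixes group_concurrent_contiguous by adding an empty-array guard and a length check, avoiding NumPy broadcast errors and silent mis-grouping.
python/sglang/srt/disaggregation/prefill.py (module: disaggregated transfer; category: source; type: core-logic; symbol: send_kv_chunk): Root-cause fix. Changes the seq_len computation in send_kv_chunk so that the state-page enumeration range matches the length registered by the decode side.
test/registered/unit/disaggregation/test_disaggregation_wire.py (module: disaggregation tests; category: test; type: test-coverage; symbols: TestGroupConcurrentContiguous, _arr, test_single_contiguous_group, test_splits_on_discontiguous_indices): New unit tests covering the empty-input and mismatched-length edge cases of group_concurrent_contiguous as regression protection.
Key symbols: group_concurrent_contiguous, send_kv_chunk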
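Key Source Snippets
python/sglang/srt/disaggregation/common/utils.py
Core defensive logic: fixes group_concurrent_contiguous by adding an empty-array guard and a length check, avoiding NumPy broadcast errors and silent mis-grouping.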
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]
) -> Tuple[List[npt.NDArray[np.int32]], List[npt.NDArray[np.int32]]]:
    """Vectorised NumPy implementation."""
    # src/dst indices are transferred pairwise, so an empty side means there is
    # nothing to transfer. Guarding both sides (not just src) avoids a cryptic
    # NumPy broadcast error from np.diff() below when only one side is empty, e.g.
    # a non-empty prefill DSA/SWA state list paired with an empty decode registration.
    if src_indices.size == 0 or dst_indices.size == 0:
        return [], []
    if src_indices.size != dst_indices.size:
        raise ValueError(
            "group_concurrent_contiguous requires equal-length src/dst index arrays, "
            f"got {src_indices.size} and {dst_indices.size}"
        )
    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
    src_groups = np.split(src_indices, brk)
    dst_groups = np.split(dst_indices, brk)
    src_groups = [g.tolist() for g in src_groups]
    dst_groups = [g.tolist() for g in dst_groups]
    return src_groups, dst_groups
python/sglang/srt/disaggregation/prefill.py
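Root-cause fix: changes the seq_len computation in send_kv_chunk so that the state-page enumeration range matches the length registered by the decode side.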
if last_chunk:
    self.disagg_metadata_buffers.set_buf(req)
    # fill_ids includes the token sampled during prefill, but decode
    # registers state pages over origin_input_ids (DecodePreallocQueue)
    # and the main pool send is clamped to end_idx above. Matching that
    # length here avoids emitting an extra state page when the sampled
    # token crosses a page boundary, which mismatched src/dst lengths in
    # group_concurrent_contiguous.
    seq_len = min(len(req.fill_ids), len(req.origin_input_ids))

    def _mamba_payload():
        return [
            self.req_to_token_pool.req_index_to_mamba_index_mapping[
                req.req_pool_idx
            ]
            .cpu()
            .numpy()
        ]

    def _swa_payload():
        window_size = self.sliding_window_size
        window_start = max(0, seq_len - window_size)
        window_start = (window_start // page_size) * page_size
        window_kv_indices_full = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, window_start:seq_len
        ]
        window_kv_indices_swa = (
            self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                window_kv_indices_full
            )
        )
        return kv_to_page_indices(
            window_kv_indices_swa.cpu().numpy(), page_size
        )

    def _dsa_payload():
        kv_indices_full = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, :seq_len
        ]
        return kv_to_page_indices(kv_indices_full.cpu().numpy(), page_size)
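The window arithmetic in _swa_payload can be worked through in isolation. This is a minimal sketch: `swa_window_start` is a hypothetical standalone function that repeats the two window_start lines from the snippet, and the numeric values are assumptions chosen for illustration.

```python
def swa_window_start(seq_len: int, window_size: int, page_size: int) -> int:
    """Start of the sliding window, rounded down to a page boundary."""
    window_start = max(0, seq_len - window_size)
    return (window_start // page_size) * page_size


# Assumed values: 1000-token prompt, 256-token window, 64-token pages.
# 1000 - 256 = 744, and rounding 744 down to a multiple of 64 gives 704.
start = swa_window_start(seq_len=1000, window_size=256, page_size=64)

# A sequence shorter than the window is sent from position 0.
short_start = swa_window_start(seq_len=100, window_size=256, page_size=64)
```

Because the slice end is seq_len, the clamp to len(req.origin_input_ids) affects both ends of the SWA range: it shifts the unaligned window start and caps where the slice stops.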
test/registered/unit/disaggregation/test_disaggregation_wire.py
New unit tests covering the empty-input and mismatched-length edge cases of group_concurrent_contiguous as regression protection.
class TestGroupConcurrentContiguous(unittest.TestCase):
    @staticmethod
    def _arr(values):
        return np.array(values, dtype=np.int32)

    def test_single_contiguous_group(self):
        src = self._arr([10, 11, 12])
        dst = self._arr([5, 6, 7])
        self.assertEqual(
            group_concurrent_contiguous(src, dst),
            ([[10, 11, 12]], [[5, 6, 7]]),
        )

    def test_splits_on_discontiguous_indices(self):
        src = self._arr([10, 11, 20])
        dst = self._arr([5, 6, 7])
        self.assertEqual(
            group_concurrent_contiguous(src, dst),
            ([[10, 11], [20]], [[5, 6], [7]]),
        )

    def test_both_empty(self):
        self.assertEqual(
            group_concurrent_contiguous(self._arr([]), self._arr([])), ([], [])
        )

    def test_empty_src_nonempty_dst(self):
        self.assertEqual(
            group_concurrent_contiguous(self._arr([]), self._arr([1, 2])), ([], [])
        )

    def test_nonempty_src_empty_dst(self):
        # Regression: a non-empty source paired with an empty destination must not
        # raise a NumPy broadcast error (observed transferring DSA sparse-attention
        # state on a disaggregated GLM deployment when decode registered zero dst indices).
        self.assertEqual(
            group_concurrent_contiguous(self._arr([1, 2]), self._arr([])), ([], [])
        )

    def test_mismatched_nonempty_lengths_raise(self):
        with self.assertRaises(ValueError):
            group_concurrent_contiguous(self._arr([1, 2, 3]), self._arr([1, 2]))
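Review Highlights
During review, gemini-code-assist[bot] raised two points:
Risks and Impact
Related Context
- PR #27011 [Bugfix] Clean up failed NIXL sender state: another fix in the disaggregation transfer layer; both PRs deal with state cleanup and crash fixes.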