Prhub

#27004 fix(disagg): correct DSA/SWA state-page transfer mismatch in PD disaggregation

原始 PR 作者 kflansburg 合并时间 2026-06-03 14:33 文件变更 3 提交数 2 评论 8 代码增减 +63 / -2

执行摘要

修复 PD 分离中 DSA/SWA 状态页传输不匹配

在 PD 分离部署的 DSA (NSA) 模型上,使用 mooncake 传输后端时,预填充端在发送最后一块时,fill_ids 已包含采样 token,导致状态页列表比解码端注册的长度多一页,进而触发 group_concurrent_contiguous 中 NumPy 广播形状不匹配崩溃。详细见 PR body:'On the last chunk, the prefill side enumerates its DSA/SWA state-page list over seq_len = len(req.fill_ids). By send time fill_ids already includes the token sampled during prefill... But the decode side registers its destination state pages over len(origin_input_ids)'。

建议阅读者精读该 PR,了解如何通过边界防御和长度限制修正复杂的分布式传输 bug。group_concurrent_contiguous 的防御性设计思路值得参考。对于 DSA 模型分离部署团队,应尽快合并。

讨论亮点

在 review 中,gemini-code-assist[bot] 提出了两个意见:

  • 关于空目标数组的返回行为:在 group_concurrent_contiguous 中,当 src_indices 非空但 dst_indices 为空时,PR 实现返回 [], []。评审认为这会静默丢弃待传输的状态,导致 decode 端状态缺失,建议改为抛出 ValueError。PR 维持了返回空的设计,未采纳该意见。
  • 对应的测试用例:评审建议将 test_nonempty_src_empty_dst 测试改为断言抛出错误,PR 同样未修改,保留了当前行为。
    最终,项目成员 ShangmingCai 批准了该 PR。设计决策倾向于认为空 dst 意味着没有页面需要注册,返回空是合理的安全规避。

实现拆解

步骤

  1. 防御 group_concurrent_contiguous 函数 (python/sglang/srt/disaggregation/common/utils.py):

    • 当 src_indices 或 dst_indices 任一为空时,立即返回空列表,而不是只在 src 空时返回。
    • 当两者均非空但长度不一时,抛出 ValueError 提供明确错误信息。
      这是对之前仅检查 src 为空的补充,避免仅 dst 为空时发生 NumPy 广播错误。
  2. 修正预填充状态页序列长度 (python/sglang/srt/disaggregation/prefill.py):

    • send_kv_chunk 方法中,将 seq_len 的计算从 len(req.fill_ids) 改为 min(len(req.fill_ids), len(req.origin_input_ids))
    • 这保证状态页枚举范围与主池传输范围(已由 end_idx 限制)一致,不会因为采样 token 多出一页。
  3. 添加单元测试 (test/registered/unit/disaggregation/test_disaggregation_wire.py):
    • 新增 TestGroupConcurrentContiguous 类,包含 6 个测试:
      • test_single_contiguous_group:正常连续分组
      • test_splits_on_discontiguous_indices:非连续分割
      • test_both_empty:双方空
      • test_empty_src_nonempty_dst:源空目标非空(返回空)
      • test_nonempty_src_empty_dst:源非空目标空(回归测试,返回空而非崩溃)
      • test_mismatched_nonempty_lengths_raise:长度不等时抛出 ValueError
文件 模块 状态 重要度
python/sglang/srt/disaggregation/common/utils.py 分离传输 modified 5.97
python/sglang/srt/disaggregation/prefill.py 分离传输 modified 5.67
test/registered/unit/disaggregation/test_disaggregation_wire.py 分离测试 modified 6.25

关键符号

group_concurrent_contiguous send_kv_chunk

关键源码片段

python/sglang/srt/disaggregation/common/utils.py core-logic

核心防御逻辑:修复 group_concurrent_contiguous 函数,增加空数组守卫和长度校验,避免 NumPy 广播错误和静默误分组。

def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]
) -> Tuple[List[npt.NDArray[np.int32]], List[npt.NDArray[np.int32]]]:
    """Vectorised NumPy implementation."""
    # src/dst indices are transferred pairwise, so an empty side means there is
    # nothing to transfer. Guarding both sides (not just src) avoids a cryptic
    # NumPy broadcast error from np.diff() below when only one side is empty, e.g.
    # a non-empty prefill DSA/SWA state list paired with an empty decode registration.
    if src_indices.size == 0 or dst_indices.size == 0:
        return [], []
​
    if src_indices.size != dst_indices.size:
        raise ValueError(
            "group_concurrent_contiguous requires equal-length src/dst index arrays, "
            f"got {src_indices.size} and {dst_indices.size}"
        )
​
    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
    src_groups = np.split(src_indices, brk)
    dst_groups = np.split(dst_indices, brk)
​
    src_groups = [g.tolist() for g in src_groups]
    dst_groups = [g.tolist() for g in dst_groups]
​
    return src_groups, dst_groups
python/sglang/srt/disaggregation/prefill.py core-logic

根因修复:修改 send_kv_chunk 中 seq_len 计算,确保状态页枚举范围与解码端注册长度一致。

        if last_chunk:
            self.disagg_metadata_buffers.set_buf(req)
​
            # fill_ids includes the token sampled during prefill, but decode
            # registers state pages over origin_input_ids (DecodePreallocQueue)
            # and the main pool send is clamped to end_idx above. Matching that
            # length here avoids emitting an extra state page when the sampled
            # token crosses a page boundary, which mismatched src/dst lengths in
            # group_concurrent_contiguous.
            seq_len = min(len(req.fill_ids), len(req.origin_input_ids))
​
            def _mamba_payload():
                return [
                    self.req_to_token_pool.req_index_to_mamba_index_mapping[
                        req.req_pool_idx
                    ]
                    .cpu()
                    .numpy()
                ]
​
            def _swa_payload():
                window_size = self.sliding_window_size
                window_start = max(0, seq_len - window_size)
                window_start = (window_start // page_size) * page_size
                window_kv_indices_full = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, window_start:seq_len
                ]
                window_kv_indices_swa = (
                    self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                        window_kv_indices_full
                    )
                )
                return kv_to_page_indices(
                    window_kv_indices_swa.cpu().numpy(), page_size
                )
​
            def _dsa_payload():
                kv_indices_full = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, :seq_len
                ]
                return kv_to_page_indices(kv_indices_full.cpu().numpy(), page_size)
test/registered/unit/disaggregation/test_disaggregation_wire.py test-coverage

新增单元测试,覆盖 group_concurrent_contiguous 的所有边界场景,确保回归防护。

class TestGroupConcurrentContiguous(unittest.TestCase):
    @staticmethod
    def _arr(values):
        return np.array(values, dtype=np.int32)
​
    def test_single_contiguous_group(self):
        src = self._arr([10, 11, 12])
        dst = self._arr([5, 6, 7])
        self.assertEqual(
            group_concurrent_contiguous(src, dst),
            ([[10, 11, 12]], [[5, 6, 7]]),
        )
​
    def test_splits_on_discontiguous_indices(self):
        src = self._arr([10, 11, 20])
        dst = self._arr([5, 6, 7])
        self.assertEqual(
            group_concurrent_contiguous(src, dst),
            ([[10, 11], [20]], [[5, 6], [7]]),
        )
​
    def test_both_empty(self):
        self.assertEqual(
            group_concurrent_contiguous(self._arr([]), self._arr([])), ([], [])
        )
​
    def test_empty_src_nonempty_dst(self):
        self.assertEqual(
            group_concurrent_contiguous(self._arr([]), self._arr([1, 2])), ([], [])
        )
​
    def test_nonempty_src_empty_dst(self):
        # Regression: a non-empty source paired with an empty destination must not
        # raise a NumPy broadcast error (observed transferring DSA sparse-attention
        # state on a disaggregated GLM deployment when decode registered zero dst indices).
        self.assertEqual(
            group_concurrent_contiguous(self._arr([1, 2]), self._arr([])), ([], [])
        )
​
    def test_mismatched_nonempty_lengths_raise(self):
        with self.assertRaises(ValueError):
            group_concurrent_contiguous(self._arr([1, 2, 3]), self._arr([1, 2]))

评论区精华

空目标数组返回行为的合理性 设计

gemini-code-assist[bot] 指出当 src_indices 非空但 dst_indices 为空时,静默返回空列表可能导致状态传输遗漏,建议改为抛出 ValueError。

结论:PR 维持返回空列表的设计,未采纳抛出异常的建议。 · 其他意见未采纳,PR 已合并

相应测试用例应改为断言异常 测试

gemini-code-assist[bot] 建议将 test_nonempty_src_empty_dst 测试改为断言抛出 ValueError。

结论:PR 保留了原测试,未修改。 · 其他意见未采纳,PR 已合并

风险与影响

主要风险:

  • 静默状态遗漏:当 dst_indices 为空而 src_indices 非空时,状态页不会被传输,可能导致 decode 端状态缺失。当前设计假设这种情况不会发生(因为解码端总是注册了相同数量的页面),但若出现不一致,不会收到错误信号。
  • 防御过度:增加了运行时长度校验,但代价很小。
  • 性能影响:仅增加两个大小检查和一次比较,无显著影响。
    总体风险较低,但 reviewer 指出的空 dst 情况需关注。

影响范围:仅限于使用 DSA/SWA 稀疏注意力的模型(如 GLM-DSA-FP8)且在 PD 分离部署下的用户。修复了预填充转解码时的崩溃,确保状态页面正确传输。无模型向前兼容问题,无性能退化。测试覆盖了核心边界。

空 dst 静默忽略 长度校验保护 防御编程提升稳健性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论