Prhub

#23539 [Bug Fix] missing index/KV transfer for MTP layer in NSA disaggregation

原始 PR 作者 zRzRzRzRzRzRzR 合并时间 2026-04-30 11:55 文件变更 2 提交数 8 评论 14 代码增减 +24 / -0

执行摘要

[NSA PD] 修复 MTP 层 draft 模型状态未传输

引用 PR body: 'In PD disaggregation with NSA + MTP, only the target model's NSA state buffers are registered for transfer. The draft model's NSATokenToKVPool buffers are never appended to kv_args, so the MTP layer's index/KV state is not sent from prefill to decode, causing wrong speculative decoding results.'

建议合并。该修复针对性强,逻辑简洁,review 后无争议。团队后续可考虑为其他状态池(如 SWA、Mamba)做类似扩展,确保统一覆盖。

讨论亮点

Review 中 kpham-sgl 指出两个问题:

1) 是否必须确保 draft 池也是 NSATokenToKVPool?2) 是否需要先检查 draft 池是否有 get_state_buf_infos 方法?ShangmingCai 进一步询问是否需要考虑非 MTP 的 spec decode 场景。作者 zRzRzRzRzRzRzR 回应:isinstance 检查确保只对 NSA draft 池生效,非 NSA 场景不受影响;当前 MTP-on-NSA 路径触发此修复,未来若出现非 MTP spec decode 使用 NSA draft,相同逻辑也正确适用。关于第二个问题,get_state_buf_infos 是 NSATokenToKVPool 类的方法,isinstance 已经保证其存在,无需额外 hasattr 检查。

实现拆解

实现分为三步:

  1. 定位问题:发现 kv_args.state_data_ptrs 只包含目标模型的 NSA 状态,draft 模型的状态未包含。
  2. decode.pyDecodePreallocQueue._init_kv_manager 中,于 isinstance(self.token_to_kv_pool, NSATokenToKVPool) 分支内添加对 self.draft_token_to_kv_pool 的检查,若其存在且同为 NSATokenToKVPool,则调用其 get_state_buf_infos() 方法获取指针/长度列表,并通过 += 追加到 kv_args 的对应列表。
  3. prefill.pyPrefillBootstrapQueue._init_kv_manager 中做完全相同的修改,保持 prefill 和 decode 侧的一致性。
    本修复未引入新配置或测试,仅改动核心数据路径。
文件 模块 状态 重要度
python/sglang/srt/disaggregation/decode.py PD 分离 modified 5.8
python/sglang/srt/disaggregation/prefill.py PD 分离 modified 5.8

关键符号

DecodePreallocQueue._init_kv_manager PrefillBootstrapQueue._init_kv_manager

关键源码片段

python/sglang/srt/disaggregation/decode.py core-logic

该文件是 Decode 节点初始化 KV 管理器的入口,修改 `DecodePreallocQueue._init_kv_manager` 方法,在主池为 NSA 时追加 draft 池的状态信息。

# DecodePreallocQueue._init_kv_manager (partial)
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
    state_data_ptrs, state_data_lens, state_item_lens = (
        self.token_to_kv_pool.get_state_buf_infos()
    )
    kv_args.state_data_ptrs = state_data_ptrs
    kv_args.state_data_lens = state_data_lens
    kv_args.state_item_lens = state_item_lens
​
    if isinstance(self.token_to_kv_pool, SWAKVPool):
        kv_args.state_type = "swa"
    elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
        kv_args.state_type = "mamba"
        if hasattr(self.token_to_kv_pool, "get_state_dim_per_tensor"):
            kv_args.state_dim_per_tensor = self.token_to_kv_pool.get_state_dim_per_tensor()
    elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
        kv_args.state_type = "nsa"
        # 修复:将 draft 模型(MTP 模块)的 NSA 状态缓冲区也加入传输列表
        if self.draft_token_to_kv_pool is not None and isinstance(
            self.draft_token_to_kv_pool, NSATokenToKVPool
        ):
            (draft_state_data_ptrs, draft_state_data_lens, draft_state_item_lens) = (
                self.draft_token_to_kv_pool.get_state_buf_infos()
            )
            kv_args.state_data_ptrs += draft_state_data_ptrs
            kv_args.state_data_lens += draft_state_data_lens
            kv_args.state_item_lens += draft_state_item_lens
    else:
        kv_args.state_type = "none"
else:
    kv_args.state_data_ptrs = []
    kv_args.state_data_lens = []
    kv_args.state_item_lens = []
    kv_args.state_type = "none"
python/sglang/srt/disaggregation/prefill.py core-logic

该文件是 Prefill 节点初始化 KV 管理器的入口,与 decode.py 做对称修改,确保 prefill 侧也传递 draft 状态。

# PrefillBootstrapQueue._init_kv_manager (partial)
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
    state_data_ptrs, state_data_lens, state_item_lens = (
        self.token_to_kv_pool.get_state_buf_infos()
    )
    kv_args.state_data_ptrs = state_data_ptrs
    kv_args.state_data_lens = state_data_lens
    kv_args.state_item_lens = state_item_lens
​
    if isinstance(self.token_to_kv_pool, SWAKVPool):
        kv_args.state_type = "swa"
    elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
        kv_args.state_type = "mamba"
        if hasattr(self.token_to_kv_pool, "get_state_dim_per_tensor"):
            kv_args.state_dim_per_tensor = self.token_to_kv_pool.get_state_dim_per_tensor()
    elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
        kv_args.state_type = "nsa"
        # 修复:将 draft 模型(MTP 模块)的 NSA 状态缓冲区也加入传输列表
        if self.draft_token_to_kv_pool is not None and isinstance(
            self.draft_token_to_kv_pool, NSATokenToKVPool
        ):
            (draft_state_data_ptrs, draft_state_data_lens, draft_state_item_lens) = (
                self.draft_token_to_kv_pool.get_state_buf_infos()
            )
            kv_args.state_data_ptrs += draft_state_data_ptrs
            kv_args.state_data_lens += draft_state_data_lens
            kv_args.state_item_lens += draft_state_item_lens
    else:
        kv_args.state_type = "none"
else:
    kv_args.state_data_ptrs = []
    kv_args.state_data_lens = []
    kv_args.state_item_lens = []
    kv_args.state_type = "none"

评论区精华

是否需要确保 draft_token_to_kv_pool 是 NSATokenToKVPool 以及检查 get_state_buf_infos 存在 正确性

kpham-sgl 提出两个问题:1) 是否必须 draft 池也是 NSATokenToKVPool?2) 是否需要检查 hasattr(self.draft_token_to_kv_pool, "get_state_buf_infos")?ShangmingCai 询问是否需要考虑非 MTP spec decode 场景。

结论:作者 zRzRzRzRzRzRzR 解释:isinstance 保护确保仅对 NSA draft 池生效,非 NSA 不受影响;get_state_buf_infos 是 NSATokenToKVPool 的方法,isinstance 已保证存在,无需额外 hasattr;当前仅 MTP-on-NSA 路径触发,未来其他路径同样适用。 · 已解决

风险与影响

技术风险:

  • 仅处理了 NSA 类型的 draft 池,若未来引入其他支持状态传输的池类型(如 SWA、Mamba),需类似扩展,否则可能遗漏传输。
  • 列表的 += 操作依赖于 kv_args.state_data_ptrs 已被正确初始化为目标模型的状态列表,在代码路径中此前提成立。
  • 缺少新增的单元测试覆盖,依赖已有的集成测试(如 8 卡测试)验证回归。
  • draft_token_to_kv_pool 为 None 或不是 NSATokenToKVPool,逻辑不执行,无副作用。

影响范围:

  • 仅影响同时启用 NSA 注意力、MTP 推测解码和 PD 分离部署的用户。在此之前,此类配置下推测解码结果错误(可能为乱码或低接受率);修复后恢复正常。
  • 不改变其他配置的行为,无性能影响。
  • 团队合并后经过 PD 相关 CI 验证通过。
核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论