Prhub

#23893 [NPU]pp support mla kv transfer

原始 PR 作者 chenxu214 合并时间 2026-05-13 09:10 文件变更 6 提交数 2 评论 13 代码增减 +92 / -22

执行摘要

NPU Ascend 后端支持 PP 下 MLA KV 传输

NPU Ascend后端在流水线并行(PP)模式下,MLA的KV传输之前只支持MHA,需要适配MLA特有的层组织方式(可能包含多个buffer groups和draft层)。该PR解决了PP场景下MLA KV传输的正确性和性能问题。

建议同后端开发人员和关注disaggregation模块的工程师精读,特别是get_mla_kv_ptrs_with_pp中的层切片算法和setup_state_kv_args中NPUMLA的处理方式,该设计从硬编码演进为结构化参数,具有参考价值。

讨论亮点
  1. 硬编码与重构:审核者ShangmingCai指出早期版本使用kv_args.state_type == "nsa"来判断每层指针数的方式"hard-coded and hacky",作者通过引入kv_buf_groups字段重构设计,消除了硬编码。

  2. 字段注释明确性:ShangmingCai要求KVArgs新增字段注释中明确标注NPU/Ascend相关,作者已采纳并修改注释。

  3. 类型归属讨论:ShangmingCai询问NPUMLATokenToKVPool为何被标记为"nsa"类型,作者解释该池即将重构为NSA类型以与原MLA区分,ShangmingCai表示理解。

实现拆解

  1. 新增数据结构字段:在disaggregation/base/conn.pyKVArgs dataclass中增加了kv_buf_groupstotal_kv_layers两个int字段,用于描述MLA buffer分组和decode端总层数,注释明确标注为NPU专用。

  2. NPU内存池暴露状态信息:在hardware_backend/npu/memory_pool_npu.pyNPUMLATokenToKVPool中新增get_state_buf_infos方法,返回index_k_buffer的指针、大小和item大小。同时将torch_npu导入改为条件导入(if is_npu():),避免非NPU环境导入错误。

  3. 配置流调整:在disaggregation/utils.pysetup_state_kv_args函数中增加total_kv_layers参数,并将NPUMLATokenToKVPool识别为与NSATokenToKVPool并列的类型。当识别到NPUMLA池时,不直接追加state组件,而是设置kv_args.kv_buf_groups(由kv_data_ptrs长度除以层数计算)和kv_args.total_kv_layers,供后续传输层切片使用。

  4. 核心层映射实现:在disaggregation/ascend/conn.pyAscendKVManager中新增get_mla_kv_ptrs_with_pp方法,根据prefill_start_layerkv_buf_groupstotal_kv_layers将decode端KV指针列表切片,准确匹配prefill的层范围。在send_kvcache中根据self.is_mla_backend分支分别调用MLA或MHA的层映射逻辑。

  5. 调用方适配:在decode.pyprefill.py_init_kv_manager中,调用setup_state_kv_args时传入total_kv_layers(值为self.scheduler.model_config.num_hidden_layers),确保decode端能获知总层数。

  6. 测试与配套:本次变更未包含直接新增的测试文件,但通过CI的run-ci标签进行回归覆盖。

文件 模块 状态 重要度
python/sglang/srt/disaggregation/ascend/conn.py 通信层 modified 7.3
python/sglang/srt/hardware_backend/npu/memory_pool_npu.py 内存池 modified 6.34
python/sglang/srt/disaggregation/utils.py 配置层 modified 6.23
python/sglang/srt/disaggregation/base/conn.py 数据类 modified 5.07
python/sglang/srt/disaggregation/decode.py 解码端 modified 4.35
python/sglang/srt/disaggregation/prefill.py 预填充端 modified 4.18

关键符号

get_mla_kv_ptrs_with_pp get_state_buf_infos setup_state_kv_args

关键源码片段

python/sglang/srt/disaggregation/ascend/conn.py core-logic

核心实现文件,新增 get_mla_kv_ptrs_with_pp 方法处理 MLA PP 层映射,修改 send_kvcache 分支。

# 文件 : python/sglang/srt/disaggregation/ascend/conn.py
class AscendKVManager(MooncakeKVManager):
    def get_mla_kv_ptrs_with_pp(
        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
    ) -> Tuple[List[int], List[int], int]:
        """计算 MLA 场景下 prefill 和 decode 端的 KV 指针层映射。
        因为 decode 端可能包含比 prefill 更多的层(如 speculative 算法添加的 draft 层),
        且 MLA 使用 kv_buf_groups 分组 k_data、v_data 和可选的 index_k_data。
        通过 prefill_start_layer 和 total_kv_layers 进行切片。
        """
        start_layer = self.kv_args.prefill_start_layer
        kv_buf_groups = getattr(self.kv_args, "kv_buf_groups", 1) # 每组包含的指针数
        total_kv_layers = getattr(self.kv_args, "total_kv_layers", 0) # decode 总层数
        src_layers = len(src_kv_ptrs) // kv_buf_groups # prefill 实际层数
​
        # 当 decode 端启用了 speculative 算法时,KV 会比 prefill 多一层 draft,需跳过
        dst_total_layers = (
            min(len(dst_kv_ptrs) // kv_buf_groups, total_kv_layers)
            if total_kv_layers
            else len(dst_kv_ptrs) // kv_buf_groups
        )
        end_layer = start_layer + src_layers
​
        if src_layers == dst_total_layers:
            sliced_dst_kv_ptrs = dst_kv_ptrs
        else:
            sliced_dst_kv_ptrs = []
            for i in range(kv_buf_groups):
                layer_offset = i * dst_total_layers
                sliced_dst_kv_ptrs.extend(
                    dst_kv_ptrs[layer_offset + start_layer: layer_offset + end_layer]
                )
        layers_current_pp_stage = len(src_kv_ptrs)
        return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
​
    def send_kvcache(self, ...):
        # ... 省略上下文 ...
        if self.pp_size > 1:
            if self.is_mla_backend:
                # MLA 分支使用新方法获取映射后的指针
                src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage = (
                    self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
                )
                layers_params = [
                    (
                        src_kv_ptrs[layer_id],
                        sliced_dst_kv_ptrs[layer_id],
                        self.kv_args.kv_item_lens[layer_id],
                    )
                    for layer_id in range(layers_current_pp_stage)
                ]
            else:
                # MHA 分支保持原有逻辑
                ...
python/sglang/srt/hardware_backend/npu/memory_pool_npu.py core-logic

新增 get_state_buf_infos 方法暴露 index_k_buffer 信息,条件导入 torch_npu 避免非 NPU 环境报错。

# 文件 : python/sglang/srt/hardware_backend/npu/memory_pool_npu.py
from sglang.srt.utils.common import is_npu# 仅在 NPU 环境下才导入 torch_npu,避免在其他硬件后端报错
if is_npu():
    import torch_npuclass NPUMLATokenToKVPool(MLATokenToKVPool):
    # ... 其他方法 ...
​
    def get_state_buf_infos(self):
        """返回 index_k_buffer 的指针、大小和 item 大小,
        供 disaggregation 流程注册 state 组件使用。
        """
        data_ptrs = [self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)]
        data_lens = [self.index_k_buffer[i].nbytes for i in range(self.layer_num)]
        item_lens = [self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)]
        return data_ptrs, data_lens, item_lens

评论区精华

硬编码设计 设计

ShangmingCai 指出早期版本使用 kv_args.state_type 判断指针数的方式 "hard-coded and hacky"。

结论:作者通过引入 kv_buf_groups 字段重构,消除了硬编码。 · 已解决

字段注释明确性 style

ShangmingCai 要求 KVArgs 新增字段注释中明确标注 NPU/Ascend 相关。

结论:作者已添加注释 "Only used of npu"。 · 已解决

类型归属讨论 设计

ShangmingCai 询问 NPUMLATokenToKVPool 是否继承自 NSATokenToKVPool,以及为何被标记为 "nsa" 类型。

结论:作者解释该池将重构为 NSA 类型以与原 MLA 区分,ShangmingCai 表示理解。 · 已解决

风险与影响

  1. 字段误用风险kv_buf_groupstotal_kv_layers仅在NPU路径下设置和使用,若在其他后端意外触发对应逻辑,可能导致未定义行为。当前通过isinstance检查和条件导入隔离了风险。
  2. 层切片计算错误get_mla_kv_ptrs_with_pp中的切片逻辑依赖于prefill_start_layertotal_kv_layers的准确性,若配置错误(如层数不匹配)会导致KV传输数据错位,造成推理错误。
  3. 缺少单元测试:没有直接对应的单元测试覆盖新逻辑,依赖CI集成测试,可能遗漏边界情况(如speculative禁用或启用时的层数差异)。

影响范围限定于NPU Ascend后端使用MLA模型并启用PP模式的KV分离传输场景。对非NPU后端、MHA模型或单机非PP模式无任何影响。用户无需更改配置即可在NPU上获得正确的MLA PP传输能力。

缺少直接单元测试 字段仅在 NPU 路径使用 层切片依赖外部配置准确性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论