执行摘要
- 一句话:NPU Ascend后端支持PP下MLA KV传输
- 推荐动作:建议同后端开发人员和关注disaggregation模块的工程师精读,特别是
get_mla_kv_ptrs_with_pp中的层切片算法和setup_state_kv_args中NPUMLA的处理方式,该设计从硬编码演进为结构化参数,具有参考价值。
功能与动机
NPU Ascend后端在流水线并行(PP)模式下,MLA的KV传输之前只支持MHA,需要适配MLA特有的层组织方式(可能包含多个buffer groups和draft层)。该PR解决了PP场景下MLA KV传输的正确性和性能问题。
实现拆解
-
新增数据结构字段:在disaggregation/base/conn.py的KVArgs dataclass中增加了kv_buf_groups和total_kv_layers两个int字段,用于描述MLA buffer分组和decode端总层数,注释明确标注为NPU专用。
-
NPU内存池暴露状态信息:在hardware_backend/npu/memory_pool_npu.py的NPUMLATokenToKVPool中新增get_state_buf_infos方法,返回index_k_buffer的指针、大小和item大小。同时将torch_npu导入改为条件导入(if is_npu():),避免非NPU环境导入错误。
-
配置流调整:在disaggregation/utils.py的setup_state_kv_args函数中增加total_kv_layers参数,并将NPUMLATokenToKVPool识别为与NSATokenToKVPool并列的类型。当识别到NPUMLA池时,不直接追加state组件,而是设置kv_args.kv_buf_groups(由kv_data_ptrs长度除以层数计算)和kv_args.total_kv_layers,供后续传输层切片使用。
-
核心层映射实现:在disaggregation/ascend/conn.py的AscendKVManager中新增get_mla_kv_ptrs_with_pp方法,根据prefill_start_layer、kv_buf_groups和total_kv_layers将decode端KV指针列表切片,准确匹配prefill的层范围。在send_kvcache中根据self.is_mla_backend分支分别调用MLA或MHA的层映射逻辑。
-
调用方适配:在decode.py和prefill.py的_init_kv_manager中,调用setup_state_kv_args时传入total_kv_layers(值为self.scheduler.model_config.num_hidden_layers),确保decode端能获知总层数。
-
测试与配套:本次变更未包含直接新增的测试文件,但通过CI的run-ci标签进行回归覆盖。
关键文件:
python/sglang/srt/disaggregation/ascend/conn.py(模块 通信层;类别 source;类型 core-logic;符号 get_mla_kv_ptrs_with_pp): 核心实现文件,新增get_mla_kv_ptrs_with_pp方法处理MLA PP层映射,修改send_kvcache分支。
python/sglang/srt/hardware_backend/npu/memory_pool_npu.py(模块 内存池;类别 source;类型 core-logic;符号 get_state_buf_infos): 新增get_state_buf_infos方法暴露index_k_buffer信息,条件导入torch_npu避免非NPU环境报错。
python/sglang/srt/disaggregation/utils.py(模块 配置层;类别 source;类型 dependency-wiring): 修改setup_state_kv_args以处理NPUMLA池,传递total_kv_layers参数,设置kv_buf_groups。
python/sglang/srt/disaggregation/base/conn.py(模块 数据类;类别 source;类型 core-logic): 在KVArgs dataclass中新增kv_buf_groups和total_kv_layers字段,作为NPU MLA PP传输的数据结构支持。
python/sglang/srt/disaggregation/decode.py(模块 解码端;类别 source;类型 core-logic): 在_init_kv_manager中调用setup_state_kv_args时传入total_kv_layers参数。
python/sglang/srt/disaggregation/prefill.py(模块 预填充端;类别 source;类型 core-logic): 在_init_kv_manager中调用setup_state_kv_args时传入total_kv_layers参数。
关键符号:get_mla_kv_ptrs_with_pp, get_state_buf_infos, setup_state_kv_args
关键源码片段
python/sglang/srt/disaggregation/ascend/conn.py
核心实现文件,新增get_mla_kv_ptrs_with_pp方法处理MLA PP层映射,修改send_kvcache分支。
# 文件 : python/sglang/srt/disaggregation/ascend/conn.py
class AscendKVManager(MooncakeKVManager):
def get_mla_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], int]:
"""计算 MLA 场景下 prefill 和 decode 端的 KV 指针层映射。
因为 decode 端可能包含比 prefill 更多的层(如 speculative 算法添加的 draft 层),
且 MLA 使用 kv_buf_groups 分组 k_data、v_data 和可选的 index_k_data。
通过 prefill_start_layer 和 total_kv_layers 进行切片。
"""
start_layer = self.kv_args.prefill_start_layer
kv_buf_groups = getattr(self.kv_args, "kv_buf_groups", 1) # 每组包含的指针数
total_kv_layers = getattr(self.kv_args, "total_kv_layers", 0) # decode 总层数
src_layers = len(src_kv_ptrs) // kv_buf_groups # prefill 实际层数
# 当 decode 端启用了 speculative 算法时,KV 会比 prefill 多一层 draft,需跳过
dst_total_layers = (
min(len(dst_kv_ptrs) // kv_buf_groups, total_kv_layers)
if total_kv_layers
else len(dst_kv_ptrs) // kv_buf_groups
)
end_layer = start_layer + src_layers
if src_layers == dst_total_layers:
sliced_dst_kv_ptrs = dst_kv_ptrs
else:
sliced_dst_kv_ptrs = []
for i in range(kv_buf_groups):
layer_offset = i * dst_total_layers
sliced_dst_kv_ptrs.extend(
dst_kv_ptrs[layer_offset + start_layer: layer_offset + end_layer]
)
layers_current_pp_stage = len(src_kv_ptrs)
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
def send_kvcache(self, ...):
# ... 省略上下文 ...
if self.pp_size > 1:
if self.is_mla_backend:
# MLA 分支使用新方法获取映射后的指针
src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage = (
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
)
layers_params = [
(
src_kv_ptrs[layer_id],
sliced_dst_kv_ptrs[layer_id],
self.kv_args.kv_item_lens[layer_id],
)
for layer_id in range(layers_current_pp_stage)
]
else:
# MHA 分支保持原有逻辑
...
python/sglang/srt/hardware_backend/npu/memory_pool_npu.py
新增get_state_buf_infos方法暴露index_k_buffer信息,条件导入torch_npu避免非NPU环境报错。
# 文件 : python/sglang/srt/hardware_backend/npu/memory_pool_npu.py
from sglang.srt.utils.common import is_npu
# 仅在 NPU 环境下才导入 torch_npu,避免在其他硬件后端报错
if is_npu():
import torch_npu
class NPUMLATokenToKVPool(MLATokenToKVPool):
# ... 其他方法 ...
def get_state_buf_infos(self):
"""返回 index_k_buffer 的指针、大小和 item 大小,
供 disaggregation 流程注册 state 组件使用。
"""
data_ptrs = [self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)]
data_lens = [self.index_k_buffer[i].nbytes for i in range(self.layer_num)]
item_lens = [self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)]
return data_ptrs, data_lens, item_lens
评论区精华
-
硬编码与重构:审核者ShangmingCai指出早期版本使用kv_args.state_type == "nsa"来判断每层指针数的方式"hard-coded and hacky",作者通过引入kv_buf_groups字段重构设计,消除了硬编码。
-
字段注释明确性:ShangmingCai要求KVArgs新增字段注释中明确标注NPU/Ascend相关,作者已采纳并修改注释。
-
类型归属讨论:ShangmingCai询问NPUMLATokenToKVPool为何被标记为"nsa"类型,作者解释该池即将重构为NSA类型以与原MLA区分,ShangmingCai表示理解。
- 硬编码设计 (design): 作者通过引入 kv_buf_groups 字段重构,消除了硬编码。
- 字段注释明确性 (style): 作者已添加注释 "Only used of npu"。
- 类型归属讨论 (design): 作者解释该池将重构为 NSA 类型以与原 MLA 区分,ShangmingCai 表示理解。
风险与影响
- 风险:
- 字段误用风险:
kv_buf_groups和total_kv_layers仅在NPU路径下设置和使用,若在其他后端意外触发对应逻辑,可能导致未定义行为。当前通过isinstance检查和条件导入隔离了风险。
- 层切片计算错误:
get_mla_kv_ptrs_with_pp中的切片逻辑依赖于prefill_start_layer和total_kv_layers的准确性,若配置错误(如层数不匹配)会导致KV传输数据错位,造成推理错误。
- 缺少单元测试:没有直接对应的单元测试覆盖新逻辑,依赖CI集成测试,可能遗漏边界情况(如speculative禁用或启用时的层数差异)。
- 影响:影响范围限定于NPU Ascend后端使用MLA模型并启用PP模式的KV分离传输场景。对非NPU后端、MHA模型或单机非PP模式无任何影响。用户无需更改配置即可在NPU上获得正确的MLA PP传输能力。
- 风险标记:缺少直接单元测试, 字段仅在NPU路径使用, 层切片依赖外部配置准确性
关联脉络
- PR #24595 [NPU] use causal_conv1d_update_v2 for performance: 同为 NPU 后端性能优化,属于 NPU 路线的演进。
- PR #25076 Fix fused_moe import for non-NPU devices: 涉及 NPU 条件导入模式,与本 PR 的 torch_npu 条件导入有相似性。
参与讨论