执行摘要
- 一句话:修复MoRI后端混合状态传输组件感知架构,消除PD引导崩溃
- 推荐动作:值得精读,特别是conn.py中组件分发逻辑和序列化方案的设计抉择(pack_int_lists vs 自定义msgpack)。开发者可以学习如何将一种传输后端与新的架构对齐,以及如何设计兼容旧格式的升级路径。
功能与动机
由于PR #24932重构混合状态传输为每组件架构,但MoRI的_register_kv_args、send_state、TransferInfo等仍使用扁平状态假设,导致struct.error: required argument is not an integer崩溃(issue #26525)。此PR将MoRI与Mooncake/NIXL已经使用的每组件调度模型对齐。
实现拆解
-
新增序列化/反序列化辅助函数:在conn.py中添加_normalize_state_indices_per_component、_pack_state_indices、_unpack_state_indices、_pack_mem_desc_lists、_unpack_mem_desc_lists,分别处理组件感知的状态索引和MemoryDesc列表的打包/解包。新增函数使用pack_int_lists/unpack_int_lists(从common.utils导入)和msgspec.msgpack嵌套编码。
-
升级数据结构定义:TransferInfo.dst_state_indices从npt.NDArray[np.int32]改为List[npt.NDArray[np.int32]];KVArgsRegisterInfo.dst_state_mem_descs从List[MemoryDesc]改为List[List[MemoryDesc]],dst_state_item_lens和dst_state_dim_per_tensor从List[int]改为List[List[int]];MoriKVManager.state_mem_descs也改为List[List[MemoryDesc]]。
-
改写send_state分发逻辑:从直接使用扁平state_indices改为迭代state_types[i],对每个组件判断类型(Mamba或SWA/DSA)后分派到_send_mamba_state或_send_swa_dsa_state独立传输。
-
更新注册和元数据路径:_register_kv_args、send_metadata等函数适配新格式,使用新打包函数序列化state_mem_descs和state_indices。
-
测试配套:添加TestMoriTransferEngineHybridMambaE2E测试类,继承MoriTransferEngineBase,使用DEFAULT_HYBRID_MAMBA_MODEL_NAME_FOR_TEST模型在8 GPU上运行烟雾测试,验证混合状态传输正确性。同时增加模型选择的可扩展性支持。
关键文件:
python/sglang/srt/disaggregation/mori/conn.py(模块 传输层;类别 source;类型 core-logic;符号 _normalize_state_indices, _normalize_state_indices_per_component, _pack_state_indices, _unpack_state_indices): 核心文件,实现所有组件感知的序列化、数据结构和传输分发逻辑,是此PR的主要变更点。
test/registered/amd/disaggregation/test_mori_transfer_engine_e2e.py(模块 MoRI测试;类别 test;类型 test-coverage;符号 TestMoriTransferEngineHybridMambaE2E, test_generate_smoke_hybrid_mamba): 添加混合Mamba模型状态传输回归测试,验证组件感知路径的正确性。
关键符号:_normalize_state_indices_per_component, _pack_state_indices, _unpack_state_indices, _pack_mem_desc_lists, _unpack_mem_desc_lists, send_state, _register_kv_args, _send_mamba_state, _send_swa_dsa_state
关键源码片段
python/sglang/srt/disaggregation/mori/conn.py
核心文件,实现所有组件感知的序列化、数据结构和传输分发逻辑,是此PR的主要变更点。
def _normalize_state_indices_per_component(
state_indices: Optional[List],
) -> Optional[List[Optional[npt.NDArray[np.int32]]]]:
# 将每组件状态索引规范化为 ravel 后的数组列表
if state_indices is None:
return None
out: List[Optional[npt.NDArray[np.int32]]] = []
for entry in state_indices:
if entry is None:
out.append(None)
else:
out.append(np.asarray(entry, dtype=np.int32).ravel())
return out
def _pack_state_indices(
state_indices: Optional[List[Optional[npt.NDArray[np.int32]]]],
) -> bytes:
# 将组件状态索引列表打包为字节流,使用 pack_int_lists ( 格式 "i")
if not state_indices:
return b""
lists = [(arr.tolist() if arr is not None else []) for arr in state_indices]
return pack_int_lists(lists, "i")
def _unpack_state_indices(buf: bytes) -> List[npt.NDArray[np.int32]]:
# 从字节流解包为组件状态索引数组列表
if not buf:
return []
return [np.asarray(lst, dtype=np.int32) for lst in unpack_int_lists(buf, "i")]
def _pack_mem_desc_lists(mems_per_comp: List[List[MemoryDesc]]) -> bytes:
# 将每组件 MemoryDesc 列表打包为嵌套 msgpack 字节流
if not mems_per_comp:
return b""
return msgspec.msgpack.encode(
[[mem.pack() for mem in comp] for comp in mems_per_comp]
)
def _unpack_mem_desc_lists(blob: bytes) -> List[List[MemoryDesc]]:
# 从嵌套 msgpack 字节流解包为每组件 MemoryDesc 列表
if not blob:
return []
nested = msgspec.msgpack.decode(blob)
return [[MemoryDesc.unpack(b) for b in comp] for comp in nested]
评论区精华
复审者ShangmingCai在conn.py的send_state函数中建议简化state_types获取:"state_types = self.kv_args.state_types is enough",并在另一位置提出类似简化。作者maning00接受并更新代码。该讨论确保了代码简洁性,且维持了正确性假设(state_types必定已设置)。
- 简化state_types获取的反馈 (correctness): 作者 maning00 接受建议并更新代码。
风险与影响
- 风险:
- 序列化兼容性:state_indices序列化从np.frombuffer改为pack_int_lists/unpack_int_lists,与其他后端(Mooncake/NIXL)独立,但若未来共享序列化工具需注意对齐。
- 性能影响:新增的列表转换(numpy与Python列表互转)在大型状态传输中可能有微小开销,但每次传输仅一次序列化,影响可忽略。
- 测试覆盖有限:仅一个混合模型烟雾测试,未覆盖多种头配置或极端大小,但已覆盖基本混合场景。
- 状态类型顺序假设:若state_types与indices顺序不匹配会导致错误传输,但代码确保从同一kv_args取值。
- 影响:对用户:修复了使用MoRI后端的PD引导崩溃,使混合状态模型(如Qwen3.5、DeepSeek V4)可正常使用。对系统:wire格式和API调整,但与其他传输后端解耦,不会影响Mooncake/NIXL。对团队:简化MoRI代码结构,使其更易于维护,并与其他后端模式统一。
- 风险标记:核心路径变更, 序列化兼容性, 测试覆盖有限
关联脉络
- PR #24932 [PD] Refactor hybrid state transfer: 该PR引入每组件状态架构,此PR完成MoRI后端的迁移对齐。
参与讨论