Prhub

#26539 [PD][MoRI] Align hybrid state transfer with per-component schema

原始 PR 作者 maning00 合并时间 2026-05-29 15:54 文件变更 2 提交数 12 评论 8 代码增减 +182 / -86

执行摘要

修复 MoRI 后端混合状态传输组件感知架构,消除 PD 引导崩溃

由于PR #24932重构混合状态传输为每组件架构,但MoRI的_register_kv_args、send_state、TransferInfo等仍使用扁平状态假设,导致struct.error: required argument is not an integer崩溃(issue #26525)。此PR将MoRI与Mooncake/NIXL已经使用的每组件调度模型对齐。

值得精读,特别是conn.py中组件分发逻辑和序列化方案的设计抉择(pack_int_lists vs 自定义msgpack)。开发者可以学习如何将一种传输后端与新的架构对齐,以及如何设计兼容旧格式的升级路径。

讨论亮点

复审者ShangmingCai在conn.py的send_state函数中建议简化state_types获取:"state_types = self.kv_args.state_types is enough",并在另一位置提出类似简化。作者maning00接受并更新代码。该讨论确保了代码简洁性,且维持了正确性假设(state_types必定已设置)。

实现拆解

  1. 新增序列化/反序列化辅助函数:在conn.py中添加_normalize_state_indices_per_component、_pack_state_indices、_unpack_state_indices、_pack_mem_desc_lists、_unpack_mem_desc_lists,分别处理组件感知的状态索引和MemoryDesc列表的打包/解包。新增函数使用pack_int_lists/unpack_int_lists(从common.utils导入)和msgspec.msgpack嵌套编码。

  2. 升级数据结构定义:TransferInfo.dst_state_indices从npt.NDArray[np.int32]改为List[npt.NDArray[np.int32]];KVArgsRegisterInfo.dst_state_mem_descs从List[MemoryDesc]改为List[List[MemoryDesc]],dst_state_item_lens和dst_state_dim_per_tensor从List[int]改为List[List[int]];MoriKVManager.state_mem_descs也改为List[List[MemoryDesc]]。

  3. 改写send_state分发逻辑:从直接使用扁平state_indices改为迭代state_types[i],对每个组件判断类型(Mamba或SWA/DSA)后分派到_send_mamba_state或_send_swa_dsa_state独立传输。

  4. 更新注册和元数据路径:_register_kv_args、send_metadata等函数适配新格式,使用新打包函数序列化state_mem_descs和state_indices。

  5. 测试配套:添加TestMoriTransferEngineHybridMambaE2E测试类,继承MoriTransferEngineBase,使用DEFAULT_HYBRID_MAMBA_MODEL_NAME_FOR_TEST模型在8 GPU上运行烟雾测试,验证混合状态传输正确性。同时增加模型选择的可扩展性支持。

文件 模块 状态 重要度
python/sglang/srt/disaggregation/mori/conn.py 传输层 modified 8.63
test/registered/amd/disaggregation/test_mori_transfer_engine_e2e.py MoRI 测试 modified 5.89

关键符号

_normalize_state_indices_per_component _pack_state_indices _unpack_state_indices _pack_mem_desc_lists _unpack_mem_desc_lists send_state _register_kv_args _send_mamba_state _send_swa_dsa_state

关键源码片段

python/sglang/srt/disaggregation/mori/conn.py core-logic

核心文件,实现所有组件感知的序列化、数据结构和传输分发逻辑,是此 PR 的主要变更点。

def _normalize_state_indices_per_component(
    state_indices: Optional[List],
) -> Optional[List[Optional[npt.NDArray[np.int32]]]]:
    # 将每组件状态索引规范化为 ravel 后的数组列表
    if state_indices is None:
        return None
    out: List[Optional[npt.NDArray[np.int32]]] = []
    for entry in state_indices:
        if entry is None:
            out.append(None)
        else:
            out.append(np.asarray(entry, dtype=np.int32).ravel())
    return out
​
​
def _pack_state_indices(
    state_indices: Optional[List[Optional[npt.NDArray[np.int32]]]],
) -> bytes:
    # 将组件状态索引列表打包为字节流,使用 pack_int_lists ( 格式 "i")
    if not state_indices:
        return b""
    lists = [(arr.tolist() if arr is not None else []) for arr in state_indices]
    return pack_int_lists(lists, "i")
​
​
def _unpack_state_indices(buf: bytes) -> List[npt.NDArray[np.int32]]:
    # 从字节流解包为组件状态索引数组列表
    if not buf:
        return []
    return [np.asarray(lst, dtype=np.int32) for lst in unpack_int_lists(buf, "i")]
​
​
def _pack_mem_desc_lists(mems_per_comp: List[List[MemoryDesc]]) -> bytes:
    # 将每组件 MemoryDesc 列表打包为嵌套 msgpack 字节流
    if not mems_per_comp:
        return b""
    return msgspec.msgpack.encode(
        [[mem.pack() for mem in comp] for comp in mems_per_comp]
    )
​
​
def _unpack_mem_desc_lists(blob: bytes) -> List[List[MemoryDesc]]:
    # 从嵌套 msgpack 字节流解包为每组件 MemoryDesc 列表
    if not blob:
        return []
    nested = msgspec.msgpack.decode(blob)
    return [[MemoryDesc.unpack(b) for b in comp] for comp in nested]

评论区精华

简化 state_types 获取的反馈 正确性

ShangmingCai 在 conn.py 的 send_state 函数中建议直接使用 `state_types = self.kv_args.state_types` 替代额外处理,因为确保该值已被设置。

结论:作者 maning00 接受建议并更新代码。 · 已解决

风险与影响

  1. 序列化兼容性:state_indices序列化从np.frombuffer改为pack_int_lists/unpack_int_lists,与其他后端(Mooncake/NIXL)独立,但若未来共享序列化工具需注意对齐。
  2. 性能影响:新增的列表转换(numpy与Python列表互转)在大型状态传输中可能有微小开销,但每次传输仅一次序列化,影响可忽略。
  3. 测试覆盖有限:仅一个混合模型烟雾测试,未覆盖多种头配置或极端大小,但已覆盖基本混合场景。
  4. 状态类型顺序假设:若state_types与indices顺序不匹配会导致错误传输,但代码确保从同一kv_args取值。

对用户:修复了使用MoRI后端的PD引导崩溃,使混合状态模型(如Qwen3.5、DeepSeek V4)可正常使用。对系统:wire格式和API调整,但与其他传输后端解耦,不会影响Mooncake/NIXL。对团队:简化MoRI代码结构,使其更易于维护,并与其他后端模式统一。

核心路径变更 序列化兼容性 测试覆盖有限

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论