Executive Summary
- One-liner: Implements a 3-read RDMA conv-state transfer for heterogeneous TP in hybrid attention+Mamba models, allowing the Prefill and Decode engines to use different TP sizes.
- Recommended action: This PR is worth a close read, especially for engineers working on distributed inference or Mamba model optimization. Focus on the design decisions: how the 3-read transfer exploits the DS layout to avoid permutation overhead, how HeteroTPTransferConfig serves as the single source of truth, and why the GQA head-mapping fix is critical for accuracy. Read it alongside #37416 and #37603 to understand the overall evolution.
Feature & Motivation
According to the PR body, the motivation is to "Enable prefill/decode disaggregation with different tensor parallelism sizes for hybrid attention+Mamba models", i.e. to let the Prefill and Decode engines run with different TP sizes (e.g. P_TP=1, D_TP=2), as an alternative to the chunk-interleaved permutation approach of #37603. The 3-read RDMA transfer eliminates permutation logic on both the P and D sides by relying on the DS conv-state layout (introduced in #37416), which keeps the x, B, and C sub-projections contiguous in memory.
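To make the idea concrete, here is a toy sketch (not code from the PR; all sizes are invented) of why a layout with per-section-contiguous x, B, and C lets each Decode rank fetch its slice in exactly three contiguous reads:

```python
# Toy demonstration of the permutation-free "3-read" transfer.
conv_rows = 3          # conv_kernel - 1
x_full, b_full = 8, 2  # P-side x / B column counts (P_TP = 1)
tp_ratio = 2           # D_TP / P_TP
d_rank = 1             # this Decode worker's TP rank

# One P page in DS layout: the x, B and C sections are each contiguous.
x = list(range(0, x_full * conv_rows))            # 24 "x" elements
B = [100 + i for i in range(b_full * conv_rows)]  # 6 "B" elements
C = [200 + i for i in range(b_full * conv_rows)]  # 6 "C" elements
p_page = x + B + C

def section_read(section_start: int, section_len: int) -> tuple[int, int]:
    """Offset/length of this D rank's contiguous slice of one section."""
    per_rank = section_len // tp_ratio
    return section_start + d_rank * per_rank, per_rank

# Because each section is contiguous, this D rank's share of the page is
# exactly three contiguous runs -> three RDMA reads, no permutation step.
reads = [
    section_read(0, len(x)),
    section_read(len(x), len(B)),
    section_read(len(x) + len(B), len(C)),
]
d_page = [v for off, n in reads for v in p_page[off:off + n]]
```

With an interleaved layout the same slice would be scattered across the page, requiring either many small reads or a permutation pass on one side.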
Implementation Breakdown
- New conv-state decomposition utilities: ssm_conv_transfer_utils.py defines the MambaConvSplitInfo dataclass, which computes per-TP-rank byte sizes and offsets for the x, B, and C sub-projections. derive_mamba_conv_split derives the split info from a MambaSpec, and compute_mamba_phys_ratio computes each engine's physical block ratio.
- Heterogeneous TP transfer configuration: utils.py adds the HeteroTPTransferConfig class as the single source of truth for descriptor sizes and read targets of both FlashAttention and Mamba across the hetero-TP scenarios, including the _physical_head_range function that fixes the GQA head mapping.
- NIXL connector core changes: nixl_connector.py adds _build_mamba_local, _build_mamba_remote, and related methods to register descriptors for the 3-read transfer; integrates HeteroTPTransferConfig to keep the FA and Mamba logic separate; and modifies _logical_to_remote_kernel_block_ids and related methods to support remote physical block mapping.
- Tests and configuration: the unit tests in test_nixl_connector_hma.py gain coverage for compute_mamba_phys_ratio, and the integration test script config_sweep_accuracy_test.sh now sets the VLLM_SSM_CONV_STATE_LAYOUT=DS environment variable.
- Environment variable requirement: a new assertion requires VLLM_SSM_CONV_STATE_LAYOUT=DS, guaranteeing the conv state is in the DS layout.
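The derivation step in the first bullet can be sketched roughly as follows. This is hypothetical: the real derive_mamba_conv_split consumes a MambaSpec object whose fields are not shown in this summary, so scalar parameters stand in for them here.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MambaConvSplitInfo:  # fields as defined in ssm_conv_transfer_utils.py
    conv_rows: int
    x_local: int
    b_local: int
    conv_dtype_size: int

def derive_mamba_conv_split(intermediate_size: int, groups_ss: int,
                            conv_kernel: int, dtype_size: int,
                            tp_size: int) -> MambaConvSplitInfo:
    """Hypothetical sketch: each rank holds 1/TP of the x columns and
    1/TP of the B/C groups; the conv state stores conv_kernel - 1 rows."""
    assert intermediate_size % tp_size == 0
    assert groups_ss % tp_size == 0
    return MambaConvSplitInfo(
        conv_rows=conv_kernel - 1,
        x_local=intermediate_size // tp_size,
        b_local=groups_ss // tp_size,
        conv_dtype_size=dtype_size,
    )
```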
Key files:
vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py (module: conv transfer utilities; category: source; type: core-logic; symbols: MambaConvSplitInfo, conv_dim_local, x_bytes, b_bytes): New conv-state decomposition utilities; the foundation of the 3-read transfer, defining MambaConvSplitInfo and its key size/offset computations.
vllm/distributed/kv_transfer/kv_connector/utils.py (module: transfer configuration; category: source; type: core-logic; symbols: _physical_head_range, _range_overlap, HeteroTPTransferConfig, post_init): New HeteroTPTransferConfig class, the single source of truth for hetero-TP transfers, handling the different split logic of FA and Mamba.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (module: NIXL connector; category: source; type: core-logic; symbols: _build_mamba_local, _build_fa_remote_for_mamba, _build_mamba_remote, _logical_to_remote_kernel_block_ids): Core NIXL connector changes integrating the 3-read transfer logic, adding the Mamba-specific methods and wiring in HeteroTPTransferConfig.
tests/v1/kv_connector/unit/test_nixl_connector_hma.py (module: HMA unit tests; category: test; type: test-coverage; symbols: test_compute_mamba_phys_ratio): Updated unit tests verifying compute_mamba_phys_ratio and the Mamba descriptor registration logic, ensuring hetero-TP support is correct.
Key symbols: MambaConvSplitInfo, derive_mamba_conv_split, compute_mamba_phys_ratio, HeteroTPTransferConfig, _physical_head_range, _build_mamba_local, _build_fa_remote_for_mamba, _build_mamba_remote, _logical_to_remote_kernel_block_ids
Key Code Snippets
vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py
New conv-state decomposition utilities; the foundation of the 3-read transfer, defining MambaConvSplitInfo and its key computations.
from dataclasses import dataclass

@dataclass(frozen=True)
class MambaConvSplitInfo:
    """Per-rank byte sizes of x, B, C sub-projections in the Mamba conv state.

    Used by both P and D sides for NIXL descriptor registration.
    All fields are LOCAL to this engine's TP (already divided by TP size).

    DS memory layout within one page (contiguous in memory):
    |--- x (x_local * conv_rows) ---|- B (b_local * conv_rows) -|- C -|
    """

    conv_rows: int        # conv_kernel - 1 (typically 3)
    x_local: int          # intermediate_size / TP (columns for x)
    b_local: int          # groups_ss / TP (columns for B; C is same size)
    conv_dtype_size: int  # bytes per element (e.g. 2 for float16)

    @property
    def conv_dim_local(self) -> int:
        """Total conv columns per rank: x + B + C."""
        return self.x_local + 2 * self.b_local

    @property
    def x_bytes(self) -> int:
        """Byte size of the x sub-projection for one rank."""
        return self.x_local * self.conv_rows * self.conv_dtype_size

    @property
    def b_bytes(self) -> int:
        """Byte size of the B (or C) sub-projection for one rank."""
        return self.b_local * self.conv_rows * self.conv_dtype_size

    @property
    def local_conv_offsets(self) -> list[tuple[int, int]]:
        """(byte_offset, byte_size) of x, B, C within this engine's page."""
        xb = self.x_bytes
        bb = self.b_bytes
        return [(0, xb), (xb, bb), (xb + bb, bb)]

    def remote_conv_offsets(
        self, local_rank_offset: int, tp_ratio: int
    ) -> list[tuple[int, int]]:
        """(byte_offset, byte_size) for D rank's slice within P page."""
        xb = self.x_bytes
        bb = self.b_bytes
        xr = xb * tp_ratio  # full remote x section in bytes
        br = bb * tp_ratio  # full remote B section in bytes
        return [
            (local_rank_offset * xb, xb),
            (xr + local_rank_offset * bb, bb),
            (xr + br + local_rank_offset * bb, bb),
        ]
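A worked example of the offset math, using a trimmed copy of the class above with illustrative sizes (float16, conv_kernel = 4, a hypothetical model with intermediate_size = 2048 and 128 SSM groups sharded at D_TP = 2):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MambaConvSplitInfo:  # trimmed copy of the class shown above
    conv_rows: int
    x_local: int
    b_local: int
    conv_dtype_size: int

    @property
    def x_bytes(self) -> int:
        return self.x_local * self.conv_rows * self.conv_dtype_size

    @property
    def b_bytes(self) -> int:
        return self.b_local * self.conv_rows * self.conv_dtype_size

    @property
    def local_conv_offsets(self) -> list[tuple[int, int]]:
        xb, bb = self.x_bytes, self.b_bytes
        return [(0, xb), (xb, bb), (xb + bb, bb)]

    def remote_conv_offsets(self, local_rank_offset: int,
                            tp_ratio: int) -> list[tuple[int, int]]:
        xb, bb = self.x_bytes, self.b_bytes
        xr, br = xb * tp_ratio, bb * tp_ratio
        return [
            (local_rank_offset * xb, xb),
            (xr + local_rank_offset * bb, bb),
            (xr + br + local_rank_offset * bb, bb),
        ]

# Illustrative sizes (not from any real model).
info = MambaConvSplitInfo(conv_rows=3, x_local=1024, b_local=64,
                          conv_dtype_size=2)
print(info.local_conv_offsets)   # [(0, 6144), (6144, 384), (6528, 384)]
# Where D rank 1's slice sits inside the larger P page (tp_ratio = 2):
print(info.remote_conv_offsets(1, tp_ratio=2))
```

Note how each of the three remote reads stays contiguous: rank 1's x slice starts at one x_bytes into the P page's x section, and its B and C slices start one b_bytes into their respective sections.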
vllm/distributed/kv_transfer/kv_connector/utils.py
New HeteroTPTransferConfig class, the single source of truth for hetero-TP transfers, handling the different split logic of FA and Mamba.
from dataclasses import dataclass, field

def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
    """Physical KV head range stored in a rank's KV cache tensor.

    When tp_size <= num_heads: sharded, K/TP contiguous heads per rank.
    When tp_size > num_heads: 1 physical head per rank, distributed contiguously.
    """
    if tp_size <= num_heads:
        assert num_heads % tp_size == 0
        per_rank = num_heads // tp_size
        return range(rank * per_rank, (rank + 1) * per_rank)
    else:
        # Fixed to a contiguous distribution, matching vLLM's GQA weight partitioning.
        h = rank * num_heads // tp_size
        return range(h, h + 1)

@dataclass
class HeteroTPTransferConfig:
    """Precomputed transfer plan for one (D rank, P engine) pair.

    Currently only instantiated for Mamba-HMA models where FA and mamba
    require different splitting factors.
    """

    # Input parameters
    tp_ratio: int
    K: int                 # total_num_kv_heads
    d_tp: int              # D engine's tensor_parallel_size
    p_tp: int              # P engine's tensor_parallel_size
    d_rank: int            # this D worker's TP rank
    use_mla: bool
    d_block_len: int       # D's block_len_per_layer
    p_block_len: int       # P's block_len_per_layer
    is_blocks_first: bool  # kv_topo.is_kv_layout_blocks_first

    # Derived attributes, computed in __post_init__
    d_physical_heads: int = field(init=False)
    p_physical_heads: int = field(init=False)
    physical_fa_num_reads: int = field(init=False)
    fa_read_targets: list[int] = field(init=False)     # P ranks that uniquely contribute FA heads
    mamba_read_targets: list[int] = field(init=False)  # P ranks that uniquely contribute Mamba state

    def __post_init__(self):
        """Compute physical heads and read targets based on GQA mapping."""
        self.d_physical_heads = len(_physical_head_range(self.d_tp, self.K, self.d_rank))
        self.p_physical_heads = len(_physical_head_range(self.p_tp, self.K, 0))  # rank 0 as representative
        # Then compute fa_read_targets and mamba_read_targets, handling replication scenarios
        # ...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Core NIXL connector changes integrating the 3-read transfer logic, adding the Mamba-specific methods and wiring in HeteroTPTransferConfig.
def _build_mamba_local(
    self,
    blocks_data: list[tuple[int, int, int]],
    base_addresses: list[int],
    block_size_ratio: int,
) -> list[tuple[int, int, int]]:
    """Register 4 desc regions (x, B, C, ssm) per layer for local mamba blocks.

    Enables 3-read transfer without permutation. Each region corresponds to
    a sub-projection of the conv state in DS layout.
    """
    if not self._has_mamba or self._conv_decomp is None:
        return []
    conv_decomp = self._conv_decomp
    mamba_regions = []
    for base_addr in base_addresses:
        # Register the four regions (x, B, C, ssm) for each cache tensor
        for offset, size in conv_decomp.local_conv_offsets:
            mamba_regions.append((base_addr + offset, size, block_size_ratio))
        # Append the SSM region, sized from conv_decomp's dimensions
        ssm_offset = conv_decomp.conv_dim_local * conv_decomp.conv_rows * conv_decomp.conv_dtype_size
        mamba_regions.append((base_addr + ssm_offset, self._mamba_ssm_size[1], block_size_ratio))
    return mamba_regions

def _logical_to_remote_kernel_block_ids(self, block_ids: BlockIds, remote_ratio: int) -> BlockIds:
    """Map logical block IDs to physical kernel block IDs on the remote.

    Critical for hetero-TP where remote may have different physical block layout.
    Early-exit uses remote_ratio (not local_ratio) to avoid data corruption.
    """
    # Fix: this previously checked local_ratio, which could cause wrong descriptor reads.
    if remote_ratio == 1:
        return block_ids
    result = []
    for group in block_ids:
        mapped = [bid * remote_ratio for bid in group]
        result.append(mapped)
    return result
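A standalone sketch of the mapping logic from the snippet above. BlockIds is assumed here to be list[list[int]], and the multiplier semantics may be richer in the real connector than this summary shows:

```python
def logical_to_remote_kernel_block_ids(
    block_ids: list[list[int]], remote_ratio: int
) -> list[list[int]]:
    # Standalone copy of the mapping shown above (no connector state).
    if remote_ratio == 1:
        return block_ids
    return [[bid * remote_ratio for bid in group] for group in block_ids]

# remote_ratio == 2: logical block b starts at kernel block 2*b on the remote.
print(logical_to_remote_kernel_block_ids([[0, 1, 5]], 2))  # [[0, 2, 10]]

# The bug the PR fixes: early-exiting on the *local* ratio. If the local
# ratio were 1 but the remote packed 2 kernel blocks per logical block,
# the mapping would be skipped and descriptors would target the wrong
# remote blocks, hence the data-corruption risk noted below.
```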
Risks & Impact
- Risks:
  - Regression risk: the new is_conv_state_dim_first() assertion may break initialization for non-Mamba models (e.g. Qwen3.5), affecting compatibility (nixl_connector.py).
  - Data corruption risk: the early-exit in _logical_to_remote_kernel_block_ids originally used local_ratio; it is corrected to remote_ratio to avoid reads through wrong descriptors (nixl_connector.py).
  - Performance risk: handling FA and Mamba separately under hetero-TP adds complexity, though the RDMA transfer optimization should offset the overhead.
  - Maintenance risk: the logic is spread across several files and is fairly complex; follow-up refactoring will be needed to keep it maintainable.
- Impact:
  - Users: enables hetero-TP deployment of hybrid attention+Mamba models, improving inference flexibility and resource utilization; however, VLLM_SSM_CONV_STATE_LAYOUT=DS is now required, which may affect existing workflows.
  - System: modifies the core distributed KV transfer path, affecting inference for all Mamba models that use the NIXL connector; tests show high accuracy is maintained across many configurations (GSM8K tests pass).
  - Team: introduces a new transfer mechanism and configuration class, adding codebase complexity the team must absorb; support for Mamba1 and gdn_attention models still needs to be added.
- Risk flags: non-Mamba model compatibility, core-path change, added complexity requiring follow-up refactoring
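For completeness, the new environment requirement amounts to the following. Only the variable itself is mandated by this PR; everything else here is an illustrative assumption about how a P/D pair might be launched:

```shell
# Required on BOTH the Prefill and Decode engines; without it the new
# assertion aborts NIXL connector initialization for Mamba-HMA models.
export VLLM_SSM_CONV_STATE_LAYOUT=DS

# Hetero-TP sizes themselves (e.g. P_TP=1, D_TP=2) are configured through
# the usual engine launch flags, which are deployment-specific and not
# shown here.
echo "VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT"
```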
Related Context
- PR #37603 [NIXL][Mamba][2/N] Heterogeneous TP: chunk-interleaved permutation: a sibling PR in the same series providing the alternative chunk-interleaved permutation approach; this PR's 3-read transfer is the optimized replacement.
- PR #37416 Introduce DS conv state layout for Mamba: introduces the DS conv-state layout (VLLM_SSM_CONV_STATE_LAYOUT=DS), the foundational dependency of this PR's 3-read transfer.