Prhub

#44539 [mamba] unify KDA conv states into one cache to match 2-state SSM layout

原始 PR 作者 ZJY0516 合并时间 2026-06-05 02:38 文件变更 3 提交数 5 评论 0 代码增减 +16 / -30

执行摘要

统一 KDA 卷积状态为一个缓存

PR body 明确说明:"Purpose align with GDN for future PD support, see https://github.com/vllm-project/vllm/pull/44064"。即为了与 GDN 架构对齐,以便将来支持前缀缓存(PD),需统一 KDA 的卷积状态布局。

建议精读。这是一个典型的数据结构统一重构,展示了如何在不改变内核计算逻辑的前提下,通过调整状态分组来适配更通用的架构。对于理解 vLLM 中 Mamba 系列(特别是 GDN/KDA)的缓存设计很有帮助。

讨论亮点

PR 没有 review 评论,reviewer WoosukKwon 和 tdoublep 均直接 approve,tdoublep 评论 "nice clean up"。

实现拆解

  1. mamba_utils.py: 修改 MambaStateDtypeCalculator.kda_state_dtype 返回类型从 (dtype, dtype, dtype, float32) 变为 (dtype, float32);修改 MambaStateShapeCalculator.kda_state_shape:将三个卷积形状(conv_state_shapeconv_state_k_shapeconv_state_k_shape)合并为一个聚合形状 conv_dim = proj_size + 2 * proj_k_size,返回类型从 4 元组变为 (conv_state_shape, recurrent_state_shape);修改 kda_state_copy_func 从返回 4 个复制函数变为 2 个。
  2. kimi_gdn_linear_attn.py: 更新 KimiGatedDeltaNetAttention.get_state_dtypeget_state_shape 返回值类型从 4 元组变为 2 元组;在 _forward 中将解包缓存从 (conv_state_q, conv_state_k, conv_state_v, recurrent_state) 改为 (conv_state, recurrent_state),然后通过 conv_state.chunk(3, dim=-2) 拆出 q/k/v 的 conv state,保持下层 kernel 调用不变。
  3. kimi_linear.py: 更新 get_mamba_state_dtype_from_configget_mamba_state_shape_from_configget_mamba_state_copy_func 的返回类型从 4 元组调整为 2 元组,与上述接口匹配。
文件 模块 状态 重要度
vllm/model_executor/layers/mamba/mamba_utils.py Mamba 工具 modified 6.84
vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py Kimi GDN modified 6.3
vllm/model_executor/models/kimi_linear.py Kimi 模型 modified 5.71

关键符号

kda_state_dtype kda_state_shape kda_state_copy_func get_state_dtype get_state_shape _forward

关键源码片段

vllm/model_executor/layers/mamba/mamba_utils.py data-contract

核心变更文件:修改了 KDA 的 dtype、shape 和 copy_func,从 4 状态压缩为 2 状态,定义了新的数据契约。

# vllm/model_executor/layers/mamba/mamba_utils.py
# 修改后:kda_state_dtype 返回 (state_dtype, float32)
@classmethod
def kda_state_dtype(cls, model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType) -> tuple[torch.dtype, torch.dtype]:
    state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
    return (state_dtype, torch.float32)# 修改后:kda_state_shape 返回 (conv_state_shape, recurrent_state_shape)
# conv_dim = proj_size + 2 * proj_k_size,合并原三个卷积状态为一个
@classmethod
def kda_state_shape(cls, tp_world_size: int, num_heads: int, head_dim: int,
                    num_k_heads: int | None = None, head_k_dim: int | None = None,
                    conv_kernel_size: int = 4, num_spec: int = 0) -> tuple[tuple[int, int], tuple[int, int, int]]:
    if num_k_heads is None:
        num_k_heads = num_heads
    if head_k_dim is None:
        head_k_dim = head_dim
    proj_size = num_heads * head_dim
    proj_k_size = num_k_heads * head_k_dim
    conv_dim = proj_size + 2 * proj_k_size
    conv_state_shape = cls._orient_conv_shape(divide(conv_dim, tp_world_size), conv_kernel_size - 1)
    recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
    return (conv_state_shape, recurrent_state_shape)# 修改后:kda_state_copy_func 返回 2 个函数
@classmethod
def kda_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    return (get_conv_copy_spec, get_temporal_copy_spec)
vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py data-contract

消费端:更新了接口返回值类型并修改缓存解包逻辑,实际使用拆分后的 conv state。

# vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py
# 修改后:从缓存解包 conv_state 和 recurrent_state
(conv_state, recurrent_state) = constant_caches
# 如果布局不是 DS(dim first),则转置
if not is_conv_state_dim_first():
    conv_state = conv_state.transpose(-1, -2)
# 从合并的 conv_state 中拆出 q/k/v 各自的卷积状态
conv_state_q, conv_state_k, conv_state_v = conv_state.chunk(3, dim=-2)
# 后续使用 conv_state_q, conv_state_k, conv_state_v 调用 causal_conv1d_fn

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险较低。改动集中在 KDA 内部状态管理,对上层模型接口(kimi_linear.py)仅涉及返回类型签名变更,且被调用方(如 prefix caching 逻辑)在合并依赖此接口的 PR 之前已经适配。测试结果(GSM8K 0.89 accuracy)表明功能正确。潜在风险:若其他未包含在此 PR 中的模块直接解包 4 元组返回值,会因解包数量不匹配而报错。但基于仓库代码范围,所有使用点均已同步修改。

直接影响 Kimi-Linear 模型系列(如 moonshotai/Kimi-Linear-48B-A3B-Instruct)的推理缓存管理。缓存从 4 个 tensor 变为 2 个,降低了内存占用和复制开销,提升性能。对用户透明,无需修改配置或代码。对团队而言,此次重构为后续 GDN/PD 功能扫清了障碍。

数据契约变更 缺少测试文件变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论