执行摘要
- 一句话:限制 Triton 注意力块大小为 128 修复 ROCm 共享内存溢出
- 推荐动作:值得仔细阅读,尤其是
has_native_kv_cache_layout 的设计和 TRITON_BLOCK_SIZE 硬上限的权衡。展示了如何通过块布局检测与内核选择协同解决硬件兼容性问题,对理解 ROCm 注意力实现和 Hybrid Mamba 支持有参考价值。
功能与动机
Hybrid Mamba models (e.g. Jamba) inflate block_size to 2048 to align attention and Mamba page sizes. When the ROCm custom paged attention kernel rejects this (it only supports 16/32), the Triton fallback kernel_paged_attention_2d used 2048 as its tile size, requesting 262144 bytes of shared memory and thus exceeding the MI325X hardware limit of 65536 bytes. Cap TRITON_BLOCK_SIZE at 128. The kernel already decouples tile size from physical block size via l_block_idx/internal_offsets addressing, so this is safe.
实现拆解
- 新增连续布局检测函数:在
vllm/v1/attention/ops/chunked_prefill_paged_decode.py 中新增 has_native_kv_cache_layout 函数,通过比较 key_cache.stride(0) 与 key_cache.shape[1:].numel() 判断 KV cache 块是否连续,从而决定使用原生 HIP 还是 Triton 路径。
- 修改 KV cache 写入路径:在
vllm/v1/attention/backends/rocm_attn.py 的 do_kv_cache_update 方法中,原条件 block_size in (16, 32) 不足以安全使用原生 write_to_paged_cache;新增 has_native_layout 检查后,只有块大小标准且布局连续时才走原生路径,否则回退到 Triton 的 reshape_and_cache_flash。
- 限制 Triton tile 大小:在
chunked_prefill_paged_decode 中引入 MAX_TRITON_BLOCK_SIZE = 128,对 Triton 的 tile 大小硬上限,避免从物理 block_size(如 2048)推导出过大 tile 导致共享内存溢出。同时对于非连续布局(not has_native_layout)强制禁用原生 kernel(use_custom = False),确保解码路径与缓存更新路径一致。
- 更新 CI 配置:在
.buildkite/test-amd.yaml 中添加 hybrid_model 测试步骤,安装来自作者分支的 mamba 与 causal-conv1d 依赖,运行混合模型测试以验证修复效果。
关键文件:
vllm/v1/attention/backends/rocm_attn.py(模块 注意力后端;类别 source;类型 core-logic): 核心 ROCm 注意力后端,修改 do_kv_cache_update 方法,根据块连续性和大小选择正确的 KV 缓存写入路径。
vllm/v1/attention/ops/chunked_prefill_paged_decode.py(模块 解码算子;类别 infra;类型 infrastructure;符号 has_native_kv_cache_layout): 定义了 has_native_kv_cache_layout 函数和 MAX_TRITON_BLOCK_SIZE 上限,是修复的核心。
.buildkite/test-amd.yaml(模块 CI配置;类别 config;类型 configuration): 添加 hybrid_model 测试步骤,确保修复后的功能在 CI 中验证。
关键符号:has_native_kv_cache_layout, do_kv_cache_update, chunked_prefill_paged_decode
关键源码片段
vllm/v1/attention/backends/rocm_attn.py
核心 ROCm 注意力后端,修改 do_kv_cache_update 方法,根据块连续性和大小选择正确的 KV 缓存写入路径。
def do_kv_cache_update(
self,
layer: AttentionLayer,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
# 从 value_cache 获取实际 block_size
# value_cache 形状 : [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# 检查 KV cache 布局是否连续
has_native_layout = has_native_kv_cache_layout(key_cache, value_cache)
if block_size in (16, 32) and has_native_layout:
# 标准 16, 32 且布局连续:使用 vLLM 原生 HIP C++ 逻辑
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
# 非标准块大小或混合注意力 /Mamba 布局需要 stride-aware 的
# Triton 写入器。原生 reshape_and_cache kernel 假设连续的
# 块存储,写入混合缓存块时会出错。
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
定义了 has_native_kv_cache_layout 函数和 MAX_TRITON_BLOCK_SIZE 上限,是修复的核心。
# has_native_kv_cache_layout 检测 KV cache 块布局是否连续
# 原生 reshape_and_cache 写入器假设 packed 块;如果缓存更新需要
# reshape_and_cache_flash(针对 stride-padded 混合布局),则解码
# 也应使用匹配的 Triton 路径。
def has_native_kv_cache_layout(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
) -> bool:
return (
key_cache.stride(0) == key_cache.shape[1:].numel()
and value_cache.stride(0) == value_cache.shape[1:].numel()
)
# ... 在 chunked_prefill_paged_decode 函数内部 ...
# Triton tile 大小的硬上限,解决 Hybrid Mamba 模型因 block_size 为 2048
# 导致共享内存申请超过 65536 字节(MI325X 硬件限制)的问题。
# 内核通过 l_block_idx / internal_offsets 解耦 tile 与物理块大小,
# 因此 128 的上限不影响功能正确性。
MAX_TRITON_BLOCK_SIZE = 128
# 对于 2 的幂的 block_size,使用 min(block_size, 128) 确保不超限;
# 非 2 的幂(如 544)则固定使用 32。
TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32
评论区精华
风险与影响
- 风险:
- 硬编码性能限制:
MAX_TRITON_BLOCK_SIZE=128 可能限制未来更大块大小模型的性能,但目前所有已知适用模型块大小 ≤ 128,无影响。
- 连续性判断开销:每次
do_kv_cache_update 调用均执行 stride 比较,开销极低可忽略。
- CI 依赖非官方分支:测试步骤依赖作者 fork 的 mamba 仓库,可能不稳定或未及时上游合并,长期应将依赖切换至官方版本。
- 路径一致性风险:若未来新增其他块大小或布局,需同步更新
do_kv_cache_update 与 chunked_prefill_paged_decode 中的条件,否则可能出现写与读路径不匹配。
- 影响:
- 用户影响:ROCm 用户现在可以运行 Hybrid Mamba 模型(如 Jamba),避免之前因共享内存 OOM 而崩溃。对其他模型无负面影响。
- 系统影响:仅修改 ROCm 注意力后端与非标准块大小的 Triton 路径,非 ROCm 平台不受影响。性能方面标准模型(block_size 16/32)仍使用原生路径,无回归;非标准块模型(如 544)从不可用变为可用,但 Triton 路径可能稍慢。
- 团队影响:需维护两套写入路径和对应的连续性检查逻辑,增加了代码复杂度。
- 风险标记:硬编码共享内存限制, ROCm 路径变更, CI 依赖作者分支, 写入与解码路径一致性
关联脉络
参与讨论