Prhub

#38502 [ROCm] Cap Triton paged attention block size to fix ROCm shared memory OOM

原始 PR 作者 AndreasKaratzas 合并时间 2026-05-10 18:03 文件变更 3 提交数 29 评论 41 代码增减 +44 / -18

执行摘要

限制 Triton 注意力块大小为 128 修复 ROCm 共享内存溢出

Hybrid Mamba models (e.g. Jamba) inflate block_size to 2048 to align attention and Mamba page sizes. When the ROCm custom paged attention kernel rejects this (it only supports 16/32), the Triton fallback kernel_paged_attention_2d used 2048 as its tile size, requesting 262144 bytes of shared memory and thus exceeding the MI325X hardware limit of 65536 bytes. Cap TRITON_BLOCK_SIZE at 128. The kernel already decouples tile size from physical block size via l_block_idx/internal_offsets addressing, so this is safe.

值得仔细阅读,尤其是 has_native_kv_cache_layout 的设计和 TRITON_BLOCK_SIZE 硬上限的权衡。展示了如何通过块布局检测与内核选择协同解决硬件兼容性问题,对理解 ROCm 注意力实现和 Hybrid Mamba 支持有参考价值。

讨论亮点
  • 硬编码 128 上限 vs 动态计算:gemini-code-assist 建议动态计算最大块大小保持通用性。Andreas 回应这是 Triton 3.6 的问题,128 对当前模型安全,未来版本会修复。
  • 连续性检测方法:hmellor 提出使用 is_contiguous(),Andreas 解释需要块级而非张量级连续性,当前 stride 比较更高效且正确。
  • 是否仅 ROCm 特定:hmellor 询问 128 上限是否 ROCm 特有,Andreas 与 tjtanaa 确认该操作仅用于 ROCm。
  • kernel 选择对齐:gshtras 指出原生 kernel 选择逻辑未同步,Andreas 随后更新 use_custom 决策以与 cache update 路径保持一致。
  • ssd_chunk_scan 修改的撤回:最初对 ssd_chunk_scan.py 的修正被视为必要,后经 micah-wil 提醒确认 Triton PR#9541 已修复编译器崩溃,遂还原该文件。

实现拆解

  1. 新增连续布局检测函数:在 vllm/v1/attention/ops/chunked_prefill_paged_decode.py 中新增 has_native_kv_cache_layout 函数,通过比较 key_cache.stride(0)key_cache.shape[1:].numel() 判断 KV cache 块是否连续,从而决定使用原生 HIP 还是 Triton 路径。
  2. 修改 KV cache 写入路径:在 vllm/v1/attention/backends/rocm_attn.pydo_kv_cache_update 方法中,原条件 block_size in (16, 32) 不足以安全使用原生 write_to_paged_cache;新增 has_native_layout 检查后,只有块大小标准且布局连续时才走原生路径,否则回退到 Triton 的 reshape_and_cache_flash
  3. 限制 Triton tile 大小:在 chunked_prefill_paged_decode 中引入 MAX_TRITON_BLOCK_SIZE = 128,对 Triton 的 tile 大小硬上限,避免从物理 block_size(如 2048)推导出过大 tile 导致共享内存溢出。同时对于非连续布局(not has_native_layout)强制禁用原生 kernel(use_custom = False),确保解码路径与缓存更新路径一致。
  4. 更新 CI 配置:在 .buildkite/test-amd.yaml 中添加 hybrid_model 测试步骤,安装来自作者分支的 mamba 与 causal-conv1d 依赖,运行混合模型测试以验证修复效果。
文件 模块 状态 重要度
vllm/v1/attention/backends/rocm_attn.py 注意力后端 modified 6.26
vllm/v1/attention/ops/chunked_prefill_paged_decode.py 解码算子 modified 6.16
.buildkite/test-amd.yaml CI 配置 modified 3.49

关键符号

has_native_kv_cache_layout do_kv_cache_update chunked_prefill_paged_decode

关键源码片段

vllm/v1/attention/backends/rocm_attn.py core-logic

核心 ROCm 注意力后端,修改 do_kv_cache_update 方法,根据块连续性和大小选择正确的 KV 缓存写入路径。

    def do_kv_cache_update(
        self,
        layer: AttentionLayer,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ):
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size
        )
​
        # 从 value_cache 获取实际 block_size
        # value_cache 形状 : [num_blocks, num_heads, head_size, block_size]
        block_size = value_cache.shape[3]
        # 检查 KV cache 布局是否连续
        has_native_layout = has_native_kv_cache_layout(key_cache, value_cache)
​
        if block_size in (16, 32) and has_native_layout:
            # 标准 16, 32 且布局连续:使用 vLLM 原生 HIP C++ 逻辑
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )
        else:
            # 非标准块大小或混合注意力 /Mamba 布局需要 stride-aware 的
            # Triton 写入器。原生 reshape_and_cache kernel 假设连续的
            # 块存储,写入混合缓存块时会出错。
            triton_reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )
vllm/v1/attention/ops/chunked_prefill_paged_decode.py infrastructure

定义了 has_native_kv_cache_layout 函数和 MAX_TRITON_BLOCK_SIZE 上限,是修复的核心。

# has_native_kv_cache_layout 检测 KV cache 块布局是否连续
# 原生 reshape_and_cache 写入器假设 packed 块;如果缓存更新需要
# reshape_and_cache_flash(针对 stride-padded 混合布局),则解码
# 也应使用匹配的 Triton 路径。
def has_native_kv_cache_layout(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> bool:
    return (
        key_cache.stride(0) == key_cache.shape[1:].numel()
        and value_cache.stride(0) == value_cache.shape[1:].numel()
    )# ... 在 chunked_prefill_paged_decode 函数内部 ...
# Triton tile 大小的硬上限,解决 Hybrid Mamba 模型因 block_size 为 2048
# 导致共享内存申请超过 65536 字节(MI325X 硬件限制)的问题。
# 内核通过 l_block_idx / internal_offsets 解耦 tile 与物理块大小,
# 因此 128 的上限不影响功能正确性。
MAX_TRITON_BLOCK_SIZE = 128
# 对于 2 的幂的 block_size,使用 min(block_size, 128) 确保不超限;
# 非 2 的幂(如 544)则固定使用 32。
TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32

评论区精华

硬编码 MAX_TRITON_BLOCK_SIZE 与动态计算 性能

gemini-code-assist 建议根据设备共享内存动态计算最大 block 大小以确保通用性。Andreas 回应当前 128 对大部分模型安全,并指出这是 Triton 3.6 问题,未来版本会修复。

结论:维持硬编码 128,后续可跟进动态计算。 · 已解决

使用 is_contiguous 替代 stride 比较 正确性

hmellor 提出使用 PyTorch 的 is_contiguous() 方法检测连续性。Andreas 解释需要块级连续性而非张量级,is_contiguous 是张量级,而 stride 比较更高效且正确。

结论:保留 stride 比较方式。 · 已解决

128 上限是否仅 ROCm 特有 question

hmellor 询问 128 上限是否 ROCm 特有限制。Andreas 和 tjtanaa 确认该操作仅用于 ROCm,且上限是当前 ROCm Triton 的限制。

结论:确认仅适用于 ROCm。 · 已解决

原生 kernel 选择逻辑与 cache update 路径对齐 设计

gshtras 指出 do_kv_cache_update 中的连续性检查也需要在 kernel 选择中体现。Andreas 随后更新了 use_custom 决策,当布局非连续时强制设为 False,确保 decode 路径与 cache update 一致。

结论:已同步更新 kernel 选择逻辑。 · 已解决

ssd_chunk_scan 修改的撤回 正确性

最初在 ssd_chunk_scan.py 中修改以绕过 AMD Triton 编译器崩溃,后经 micah-wil 提醒被 Triton PR#9541 修复,Andreas 测试后确认不再需要,还原该文件。

结论:ssd_chunk_scan.py 已恢复原状。 · 已解决

风险与影响

  • 硬编码性能限制MAX_TRITON_BLOCK_SIZE=128 可能限制未来更大块大小模型的性能,但目前所有已知适用模型块大小 ≤ 128,无影响。
  • 连续性判断开销:每次 do_kv_cache_update 调用均执行 stride 比较,开销极低可忽略。
  • CI 依赖非官方分支:测试步骤依赖作者 fork 的 mamba 仓库,可能不稳定或未及时上游合并,长期应将依赖切换至官方版本。
  • 路径一致性风险:若未来新增其他块大小或布局,需同步更新 do_kv_cache_updatechunked_prefill_paged_decode 中的条件,否则可能出现写与读路径不匹配。
  • 用户影响:ROCm 用户现在可以运行 Hybrid Mamba 模型(如 Jamba),避免之前因共享内存 OOM 而崩溃。对其他模型无负面影响。
  • 系统影响:仅修改 ROCm 注意力后端与非标准块大小的 Triton 路径,非 ROCm 平台不受影响。性能方面标准模型(block_size 16/32)仍使用原生路径,无回归;非标准块模型(如 544)从不可用变为可用,但 Triton 路径可能稍慢。
  • 团队影响:需维护两套写入路径和对应的连续性检查逻辑,增加了代码复杂度。
硬编码共享内存限制 ROCm 路径变更 CI 依赖作者分支 写入与解码路径一致性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论