Prhub

#41812 [ROCm][DSv4] implement flash sparse mla with triton kernels

原始 PR 作者 whx-sjtu 合并时间 2026-05-12 00:27 文件变更 6 提交数 6 评论 13 代码增减 +1849 / -212

执行摘要

用 Triton 为 ROCm DeepSeekV4 稀疏 MLA 加速

DeepSeek V4 的稀疏 MLA 在 ROCm 上原有实现基于 torch reference,性能受限且无法充分利用 GPU 并发。该 PR 旨在通过 Triton 编写专门 kernel,提升推理吞吐。PR body 说明:'replace ROCm's torch reference implementation of deepseek v4 sparse mla with triton kernels to support larger concurrency and improve performance.'

该 PR 值得精读,尤其是新增的 Triton kernel 实现和 ROCm backend 集成方式。设计决策中,将 platform-specific 逻辑从 model layer 下沉到 backend 选择是良好的分离。但需关注 review 中提出的正确性风险是否在合并前解决。

讨论亮点

Review 中 gemini-code-assist[bot] 指出了三个关键问题:

  • KV cache 布局假设:新 kernel 假设 token 数据和 scale 在 block 内为平面布局,但实际 flashmla_sparse.py 使用 interleaved 布局(shape [num_blocks, block_size, 584]),导致指针算术错误。
  • FP8 类型分歧:使用了 NVIDIA 的 tl.float8e4nv,而 ROCm 应使用 tl.float8e4m3fnuz
  • decode kernel 两遍加载:对 topk 索引遍历两次,重复加载 KV cache 行,建议使用在线 softmax 优化。
    此外,tjtanaa 建议将新 backend 移到独立文件(已采纳),AndreasKaratzas 询问 gfx942 兼容性,whx-sjtu 回复仅测试了 gfx950,未处理 fnuz 格式。最终 PR 获得两个 APPROVAL 后合入。

实现拆解

  1. 新增 ROCm 专用 attention backend:创建 vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py,定义 DeepseekV4ROCMAiterMLASparseBackendDeepseekV4ROCMAiterMLASparseImpl,分别继承自 FlashMLASparseBackendSparseMLAAttentionImpl,重写 get_impl_clsforward 等方法以调用 Triton kernel。

  2. 在算子库中实现 Triton kernel:在 vllm/v1/attention/ops/rocm_aiter_mla_sparse.py 中新增多个 Triton JIT kernel,包括 _sparse_attn_prefill_ragged_kernel_sparse_attn_decode_ragged_kernelbuild_ragged_indices_from_densecompute_global_topk_ragged_indices_and_indptr 等,实现 FP8 量化 KV cache 的加载、解量化、稀疏索引、softmax 和输出计算。

  3. 修改模型层以分平台选择 backend:在 vllm/model_executor/layers/deepseek_v4_attention.py 中,get_attn_backend 方法增加 ROCm 分支,返回新 backend;移除 _forward_decode_forward_prefill 中旧的 ROCm 特判分支,统一调用 backend 实现。

  4. 调整现有 backend 接口:在 flashmla_sparse.py 中泛化 get_impl_cls 返回值类型;在 sparse_swa.py 中为 get_builder_cls 增加 ROCm 平台分支,使用新 backend 的 metadata builder。

  5. 新增单元测试:创建 tests/kernels/attention/test_rocm_triton_attn_dsv4.py,包含 test_compute_global_topk_ragged_indices_and_indptrtest_sparse_attn_prefill_ragged_kernel 等测试,对比 Triton kernel 输出与 PyTorch reference 实现的一致性,并添加 FP8 cache 打包/读取辅助函数。

文件 模块 状态 重要度
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py 稀疏 MLA 后端 added 9.08
tests/kernels/attention/test_rocm_triton_attn_dsv4.py 测试 added 7.76
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py 算子库 modified 7.22
vllm/model_executor/layers/deepseek_v4_attention.py 注意力层 modified 6.94
vllm/v1/attention/backends/mla/flashmla_sparse.py 稀疏 MLA modified 5.06
vllm/v1/attention/backends/mla/sparse_swa.py 稀疏 SWA modified 4.95

关键符号

compute_global_topk_ragged_indices_and_indptr _sparse_attn_prefill_ragged_kernel _sparse_attn_decode_ragged_kernel build_ragged_indices_from_dense combine_topk_swa_indices_ragged

关键源码片段

vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py core-logic

新增的 ROCm 专用 attention backend,封装了 Triton kernel 调用和 metadata 构建,是整个 PR 的核心。

# SPDX-License-Identifier: Apache-2.0
# (c) vLLM contributorsimport torch
from vllm.triton_utils import tl, triton
​
​
def _build_indptr_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """从每个 token 的有效 topk 长度构建 indptr 数组。"""
    lengths = lengths.to(dtype=torch.int32).contiguous()
    indptr = torch.zeros(lengths.shape[0] + 1, dtype=torch.int32, device=lengths.device)
    torch.cumsum(lengths, dim=0, out=indptr[1:])
    return indptr
​
​
@triton.jit
def _compute_topk_lens_kernel(
    topk_lens_ptr,
    topk_indices_ptr,
    topk_indices_stride,
    topk,
    is_valid_token_ptr,
    TRITON_BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel:计算每个 token 有效 topk 索引的数量。"""
    token_idx = tl.program_id(0)
    is_valid_token = tl.load(is_valid_token_ptr + token_idx)
    count = tl.zeros((), dtype=tl.int32)
    for i in range(0, topk, TRITON_BLOCK_SIZE):
        offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
        mask = offset < topk
        local_idx = tl.load(
            topk_indices_ptr + token_idx * topk_indices_stride + offset,
            mask=mask,
            other=-1,
        )
        count += tl.sum((local_idx >= 0).to(tl.int32), axis=0)
    tl.store(topk_lens_ptr + token_idx, tl.where(is_valid_token, count, 0))
​
​
def compute_global_topk_ragged_indices_and_indptr(
    topk_indices: torch.Tensor,
    token_to_req_indices: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    is_valid_token: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """将稠密的 topk 索引转换为 ragged 布局,
    同时通过 block_table 将局部索引映射为全局 slot ID。"""
    topk_indices = topk_indices.reshape(topk_indices.shape[0], -1).contiguous()
    num_tokens = topk_indices.shape[0]
    topk = topk_indices.shape[1]
​
    topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device)
    _compute_topk_lens_kernel[(num_tokens,)](
        topk_lens,
        topk_indices,
        topk_indices.stride(0),
        topk,
        is_valid_token,
        TRITON_BLOCK_SIZE=1024,
    )
​
    topk_indptr = _build_indptr_from_lengths(topk_lens)
    global_topk_ragged = torch.empty(
        num_tokens * topk,
        dtype=torch.int32,
        device=topk_indices.device,
    )
    if global_topk_ragged.numel() > 0:
        block = 128
        _pack_global_topk_ragged_kernel[(num_tokens, triton.cdiv(topk, block))](
            global_topk_ragged,
            topk_indptr,
            topk_indices,
            topk_indices.stride(0),
            token_to_req_indices,
            block_table,
            block_table.stride(0),
            block_size,
            topk,
            BLOCK_SIZE=block,
        )
    return global_topk_ragged, topk_indptr, topk_lens

评论区精华

KV cache 布局假设与指针算术错误 正确性

gemini-code-assist[bot] 指出新 Triton kernel 中的指针算术假设 KV cache 为平面布局,但实际 flashmla_sparse.py 中 cache shape 是 [num_blocks, block_size, 584],每个 token 的数据和 scale 是连续的,而非平面化,导致读取错误数据。

结论:未在 thread 中看到直接回复或修正,但 PR 最终被批准合并,可能已在后续提交中修复或认为当前实现已正确处理。 · 已解决

FP8 类型使用 NVIDIA 专有类型 float8e4nv 正确性

gemini-code-assist[bot] 指出 kernel 中使用 tl.float8e4nv 是 NVIDIA 类型,ROCm 应使用 float8e4m3fnuz,否则可能编译错误或结果错误。

结论:未看到直接回复,但 PR 合入,可能已修改为 ROCm 兼容类型或作者认为当前架构不受影响。 · 已解决

将新 backend 移到独立文件 设计

tjtanaa 建议将新的 backend 从 rocm_aiter_mla_sparse.py 移到独立的 rocm_aiter_mla_sparse_dsv4.py。whx-sjtu 回复 'sure'。

结论:已采纳,最终 PR 中 backend 实现在新文件中。 · 已解决

gfx942 兼容性问题 question

AndreasKaratzas 询问新的 Triton kernel 是否支持 gfx942(MI300)。whx-sjtu 回复只测试了 gfx950,未处理 fnuz 格式,可能不兼容。

结论:未解决,建议后续支持 fnuz 格式以便在 gfx942 上运行。 · unresolved

风险与影响

  • 兼容性风险:新 Triton kernel 仅针对 AMD MI355X (gfx950) 验证,对于其他 ROCm 架构(如 gfx942)可能因未处理 fnuz FP8 格式而功能异常。
  • 正确性风险:review 指出的指针布局假设可能与实际 KV cache 存储格式不一致,若未修正会导致推理结果错误(需确认是否在合入前修复)。
  • 性能风险:decode kernel 采用两遍加载方案,未使用在线 softmax 优化,可能未充分利用内存带宽。
  • 维护风险:新增大量 Triton kernel 代码依赖 ROCm 特定算子库,需要专人维护。
  • 用户:ROCm 上部署 DeepSeek V4 的用户将获得显著的性能提升(kernel 时间减少 <100μs),但仅限 gfx950 架构;gfx942 用户需验证兼容性。
  • 系统:新增了约 1.8k 行代码,扩展了 attention backend 体系;修改了 model layer 的选择逻辑,使平台特化的实现路径更清晰。
  • 团队:需要维护 ROCm 专用 Triton kernel,并跟踪上游 FlashMLA 接口变更。
仅支持 gfx950 FP8 类型分歧 kernel 布局假设 decode 双 pass

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论