Prhub

#25311 perf(mla): TMA bulk-store set_mla_kv_buffer (up to 12× over baseline)

原始 PR 作者 ch-wan 合并时间 2026-05-15 09:23 文件变更 5 提交数 1 评论 4 代码增减 +678 / -5

执行摘要

优化 MLA KV 缓存写入,性能提升最高 12 倍

MLA paged-KV 散射写入(set_mla_kv_buffer)原本是一个 1D Triton 内核(BLOCK=128, grid (n_loc, ceil(total_dim/BLOCK))),在批量较小时表现尚可,但随着 n_loc 线性退化——在 GB300 上 bs=16384 时达到 83.5μs。对于 DeepSeek-V4 prefill(61 层 × 每步数千个 loc)这是层时间中相当可观的部分。

值得精读。该 PR 展示了 GPU 内核优化的完整工程实践:从瓶颈识别、多种实现方案对比、自动调度到测试和基准覆盖,并处理了 TMA 硬件特有的正确性细节。可学习其设计决策和阈值调优方法。

讨论亮点

Reviewer BBuf 提出了两条评论:

  • 向量化内存访问宽度(文件 set_mla_kv_buffer.cuh 第 57 行):指出 Blackwell/GB300 支持 32 字节向量化访问,暗示可以进一步利用。目前内核中向量宽度基于对齐推导,可能未充分使用 32 字节。结论:作者未回复,PR 已合并,可能认为当前实现已足够或需要后续迭代。
  • 减少重复计算(第 100 行):locgmem_dst 地址计算仅被 lane 0 用于 TMA 操作,建议移至 lane-0 分支内以避免其他 lane 的冗余工作。结论:同样未解决,但该优化可能对性能影响较小。

实现拆解

  1. 新增 JIT CUDA TMA 批量存储内核python/sglang/jit_kernel/csrc/elementwise/set_mla_kv_buffer.cuh):每个 warp 将一行 (nope, rope) 加载到共享内存,然后 lane 0 发出 cp.async.bulk.global.shared::cta 指令将整行散射写入 kv_buffer。对批量大(≥768)的场景,每个 CTA 可处理 4-8 行,远低于每行一个 CTA 的开销。

  2. Triton 路径优化python/sglang/srt/mem_cache/utils.py):Triton 内核的 BLOCK 从固定 128 改为 next_pow2(nope_dim + rope_dim),使得每个 CTA 覆盖一整行,消除了边界分支和额外的 CTA 发散。小幅批量(<768)时性能提升 1.01-3.11 倍。

  3. 自动调度与兼容性封装python/sglang/jit_kernel/set_mla_kv_buffer.py):新增 set_mla_kv_buffer 函数作为 TMA 路径的入口,配合 can_use_set_mla_kv_buffer 检查行宽度对齐和架构支持。原 set_mla_kv_buffer_triton 函数保留名称,内部根据 n_loc 和架构条件自动选择 TMA 或 Triton 路径。

  4. 测试与基准配套

    • 测试python/sglang/jit_kernel/tests/test_set_mla_kv_buffer.py):覆盖多种数据类型(bf16, fp16)、形状(含 FP8 NSA 字节布局)、批量大小(含空 loc)和 loc 数据类型,共 55 个测试。
    • 基准python/sglang/jit_kernel/benchmark/bench_set_mla_kv_buffer.py):提供 wrapper(自动调度)、jit_tma(直接 TMA)、triton(原 Triton 基线)三组对比,在 CI 中注册为 stage-b-kernel-benchmark。
文件 模块 状态 重要度
python/sglang/jit_kernel/set_mla_kv_buffer.py JIT 内核 added 8.93
python/sglang/jit_kernel/benchmark/bench_set_mla_kv_buffer.py 基准测试 added 8.64
python/sglang/jit_kernel/tests/test_set_mla_kv_buffer.py 测试 added 7.83
python/sglang/srt/mem_cache/utils.py 缓存层 modified 7.0
python/sglang/jit_kernel/csrc/elementwise/set_mla_kv_buffer.cuh JIT 内核 added 6.04

关键符号

set_mla_kv_buffer can_use_set_mla_kv_buffer _pick_num_warps _jit_set_mla_kv_buffer_module set_mla_kv_buffer_triton benchmark _triton_baseline test_set_mla_kv_buffer_correctness test_set_mla_kv_buffer_loc_dtypes test_set_mla_kv_buffer_uint8_byte_layout test_set_mla_kv_buffer_empty_loc test_can_use_set_mla_kv_buffer

关键源码片段

python/sglang/jit_kernel/set_mla_kv_buffer.py core-logic

核心 Python 包装器,定义 TMA 路径的入口函数和兼容性检查,是整个优化的调度枢纽。

"""JIT TMA bulk-store path for ``set_mla_kv_buffer``.Each warp scatter-writes one item's (nope, rope) row via a single
``cp.async.bulk.global.shared::cta`` store. Requires SM90+ (Hopper or later)
for the TMA bulk-store hardware. The host-side wrapper in
``sglang.srt.mem_cache.utils`` falls back to a Triton kernel for older arches.
"""from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import torch
from sglang.jit_kernel.utils import cache_once, is_arch_support_pdl, load_jit, make_cpp_argsif TYPE_CHECKING:
    from tvm_ffi.module import Modulelogger = logging.getLogger(__name__)
​
​
@cache_once
def _jit_set_mla_kv_buffer_module(
    nope_bytes: int, rope_bytes: int, use_pdl: bool
) -> Module:
    # 构建编译参数并加载 JIT CUDA 内核
    args = make_cpp_args(nope_bytes, rope_bytes, use_pdl)
    return load_jit(
        f"set_mla_kv_buffer_{nope_bytes}_{rope_bytes}",
        *args,
        cuda_files=["elementwise/set_mla_kv_buffer.cuh"],
        cuda_wrappers=[
            ("set_mla_kv_buffer", f"SetMlaKVBufferKernel<{args}>::run"),
        ],
    )
​
​
@cache_once
def can_use_set_mla_kv_buffer(nope_bytes: int, rope_bytes: int) -> bool:
    # TMA 要求行总字节数是 16 的倍数,且每部分字节数是 4 的倍数
    if nope_bytes % 4 != 0 or rope_bytes % 4 != 0:
        return False
    if (nope_bytes + rope_bytes) % 16 != 0:
        return False
    try:
        _jit_set_mla_kv_buffer_module(nope_bytes, rope_bytes, is_arch_support_pdl())
        return True
    except Exception as e:
        logger.warning("Failed to load JIT kernel: %s", e)
        return False
​
​
def _pick_num_warps(n_loc: int) -> int:
    # 在 GB300 上调优:小批量更多 warp 以利用 SM,大批量减少 warp 以分摊 bulk-group commit 的开销
    return 4 if n_loc <= 768 else 8
​
​
def set_mla_kv_buffer(
    kv_buffer: torch.Tensor,
    loc: torch.Tensor,
    cache_k_nope: torch.Tensor,
    cache_k_rope: torch.Tensor,
    num_warps: int = 0,
) -> None:
    # 使用 TMA 批量存储写入 KV 缓冲区,仅在 SM90+ 架构上调用
    n_loc = loc.shape[0]
    if n_loc == 0:
        return
    src_nope = cache_k_nope.view(n_loc, -1) if cache_k_nope.dim() != 2 else cache_k_nope
    src_rope = cache_k_rope.view(n_loc, -1) if cache_k_rope.dim() != 2 else cache_k_rope
    buf = kv_buffer.view(kv_buffer.shape[0], -1) if kv_buffer.dim() != 2 else kv_buffer
    nope_bytes = src_nope.shape[-1] * src_nope.element_size()
    rope_bytes = src_rope.shape[-1] * src_rope.element_size()
    if num_warps <= 0:
        num_warps = _pick_num_warps(n_loc)
    module = _jit_set_mla_kv_buffer_module(nope_bytes, rope_bytes, is_arch_support_pdl())
    module.set_mla_kv_buffer(buf, loc, src_nope, src_rope, num_warps)
python/sglang/jit_kernel/tests/test_set_mla_kv_buffer.py test-coverage

正确定测试覆盖多种数据类型、形状、批量大小和边界情况,确保优化不引入错误。

import sys
import pytest
import torch
from sglang.jit_kernel.set_mla_kv_buffer import can_use_set_mla_kv_buffer, set_mla_kv_buffer
from sglang.jit_kernel.utils import get_ci_test_range
from sglang.test.ci.ci_register import register_cuda_ciregister_cuda_ci(est_time=30, suite="stage-b-kernel-unit-1-gpu-large")DEVICE = "cuda"
CACHE_SIZE = 4096# (nope_dim, rope_dim) pairs: standard MLA, MLA scale buffer, FP8 nope-extended layout
SHAPES = get_ci_test_range(
    [(512, 64), (512, 32), (256, 64), (128, 64), (528, 64)],
    [(512, 64), (528, 64)],
)
BATCH_SIZES = get_ci_test_range([1, 7, 64, 257, 1024], [1, 64, 1024])
​
​
def _ref(kv_buffer, loc, cache_k_nope, cache_k_rope):
    # 纯 PyTorch 参考实现:直接索引赋值
    nope_dim = cache_k_nope.shape[-1]
    n_loc = loc.shape[0]
    src_nope = cache_k_nope.reshape(n_loc, -1)
    src_rope = cache_k_rope.reshape(n_loc, -1)
    kv_view = kv_buffer.view(kv_buffer.shape[0], -1)
    kv_view[loc.long(), :nope_dim] = src_nope
    kv_view[loc.long(), nope_dim : nope_dim + src_rope.shape[-1]] = src_rope
​
​
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_set_mla_kv_buffer_correctness(dtype, shape, batch_size):
    # 对比内核输出与参考实现,要求逐元素一致
    nope_dim, rope_dim = shape
    total_dim = nope_dim + rope_dim
    cache_k_nope = torch.randn((batch_size, 1, nope_dim), dtype=dtype, device=DEVICE)
    cache_k_rope = torch.randn((batch_size, 1, rope_dim), dtype=dtype, device=DEVICE)
    kv_buffer = torch.randn((CACHE_SIZE, 1, total_dim), dtype=dtype, device=DEVICE)
    kv_ref = kv_buffer.clone()
    loc = torch.randperm(CACHE_SIZE, device=DEVICE)[:batch_size]
    set_mla_kv_buffer(kv_buffer, loc, cache_k_nope, cache_k_rope)
    _ref(kv_ref, loc, cache_k_nope, cache_k_rope)
    assert torch.equal(kv_buffer, kv_ref)
​
​
@pytest.mark.parametrize("loc_dtype", [torch.int32, torch.int64])
def test_set_mla_kv_buffer_loc_dtypes(loc_dtype):
    # 确保两种 loc 数据类型均正常工作
    ... # 核心逻辑类似,忽略
​
​
def test_set_mla_kv_buffer_uint8_byte_layout():
    # 覆盖 FP8 NSA 字节布局(528 + 128 = 656 字节)
    ... # 核心逻辑类似,忽略
​
​
def test_set_mla_kv_buffer_empty_loc():
    # 空 loc 时不应修改缓冲区
    ...
​
​
def test_can_use_set_mla_kv_buffer():
    assert can_use_set_mla_kv_buffer(1024, 128) # bf16 (512,64)
    assert can_use_set_mla_kv_buffer(528, 128) # fp8 byte layout
    assert not can_use_set_mla_kv_buffer(13, 8) # not multiple of 4

评论区精华

Blackwell 32 字节向量化访问 性能

Reviewer BBuf 指出 Blackwell/GB300 支持 32 字节向量化内存访问,暗示当前内核可能未充分利用。

结论:作者未回复,PR 已合并;可能认为当前向量宽度已足够或待后续优化。 · COMMENTED

减少 loc 和地址计算的重复工作 性能

Reviewer BBuf 建议将 loc 加载和 gmem_dst 地址计算移至 lane-0 分支,避免其他 lane 的不必要计算。

结论:作者未回复或修改,PR 已合并;可能认为编译器优化可消除多余计算,或影响很小。 · COMMENTED

风险与影响

  1. GPU 架构依赖:TMA 内核需要 SM90+(Hopper/Blackwell),在不支持的架构上会回退到 Triton,但需确保 is_arch_support_pdl() 检测正确。目前该函数基于 CUDA capability 判断,风险可控。
  2. TMA 正确性条件:内核依赖 fence.proxy.async.shared::ctawait_group<0> 保证数据可见性。若 future CUDA 版本改变语义,可能出现写入不完整。已通过注释明确标记。
  3. 批量阈值调优:TMA 与 Triton 的分界阈值(768)基于 GB300 调优,在其他 GPU(如 H100)上可能非最优,但性能仍优于原基线。
  4. Blackwell 向量化未利用:review 指出可以支持 32 字节向量加载,当前仅用到 16/8/4 字节,可能留有微优化空间,但不导致功能错误。

对使用 DeepSeek-V4(或类似 MLA 架构)的用户,prefill 阶段性能显著提升,尤其是长序列或大批量场景。对系统,需 CUDA 12.4+ 和 TMA 支持的 GPU;代码向后兼容,不影响现有模型。对团队,新增了 JIT 内核模式范例,有助于后续类似优化。

GPU 架构依赖(SM90+) TMA 正确性要求(fence/wait_group) 阈值基于 GB300 调优 Blackwell 32 字节访问未启用

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论