Prhub

#25333 perf(mla): hybrid Triton fused cat+FP8-quantize for MLA chunked-prefill K/V

原始 PR 作者 ch-wan 合并时间 2026-05-16 01:51 文件变更 4 提交数 6 评论 2 代码增减 +641 / -12

执行摘要

MLA chunked-prefill K/V 融合 cat+FP8 量化单核,最高 10x 加速

在MLA chunked-prefill路径中,构建K张量需要先分配BF16中间张量,再执行k_nope和k_pe的拼接,然后分别对K和V进行per-tensor FP8量化,总共三次分离操作。每次操作都会通过全局内存产生中间数据,浪费显存带宽和启动开销。融合为一个内核可以消除这两次中间全局内存写入,同时允许在寄存器级别完成类型转换。如PR body所述:“The straightforward implementation is three sequential ops ... which dispatches the same (s, num_heads) work three times, round-trips intermediate BF16/FP16 values through gmem between the concat and the quantize, and cannot share PDL hand-off.”

建议精读。该PR不仅带来了显著的性能提升(5.5×-10×),还在以下方面具有工程借鉴价值:

  • 通过Triton内核融合消除中间全局内存数据,是GPU性能优化的典型手法。
  • 混合调度器针对不同batch size选择网格维度和配置,体现了对GPU计算/内存行为分区的深刻理解。
  • 通过duck-typing挂钩集成,无需修改现有注意力后端,保持了接口清晰和回退安全。
  • 完整的性能调优过程和Benchmark表格可作为同类优化的参考模板。
讨论亮点

本PR未触发人工审查评论,仅有一条gemini-code-assist的配额警告和作者触发的/tag-and-rerun-ci命令。PR body中包含了极其详尽的性能调优说明和完整数据表格,展示了从朴素Triton到混合调度器每次迭代的改进幅度,以及最终在GB300上相比PyTorch eager 5.5×-10×的加速表,体现了作者对Triton内核设计的深入考量。

实现拆解

  1. 设计两个Triton内核变体_v0_kernel(2D网格 (s, h))适合小batch,_v1_flat_kernel(1D扁平网格 s * num_heads)适合大batch。两者均在单一内核内完成加载BF16 k_nope/k_pe/v、乘以scale inverse、转换为FP8、并写入连续K输出缓冲区(k_nope在前,k_pe在后)和V输出缓冲区。
  2. 实现混合调度器 _pick_kernel:基于GB300预调参,根据s选择变体和配置(BLOCK_S/num_warps/num_stages/PDL使能)。s ≤ 8时使用_v0_kernel,后者内置4组配置带;s > 8时使用_v1_flat_kernel,配置按总元素数 s * num_heads 分为3个区间。PDL在架构支持时启用(可为小型batch带来40%以上性能提升)。
  3. 集成到MHA chunked-prefixforward_mha.py):在_chunked_prefix_attn_mha中通过getattr(backend, "pack_prefix_chunk_kv", None)进行duck-typing检测。当后端提供该挂钩时,直接调用融合内核(mla_kv_pack_quantize_fp8)替代原torch.empty + slice赋值 + to(fp8)路径,kv_b_proj的输入类型也相应调整为BF16;否则保持原行为不变。新增_resolve_attn_backend辅助函数解包TboAttnBackend包装。
  4. 配套基准测试bench_mla_kv_pack_quantize_fp8.py):使用triton.testing.Benchmark对比混合内核、朴素Triton基线和PyTorch eager实现,覆盖BS 1~16384,输出微秒级延迟。注册为CI基准stage-b-kernel-benchmark-1-gpu-large
  5. 配套正确性测试test_mla_kv_pack_quantize_fp8.py):参数化测试覆盖4种dtype对、6种head维度组合、8种batch size、4种head count,共130个case。参考实现_ref在BF16中构建完整K张量后执行FP8量化,与融合内核结果在rtol=1e-2/atol=0.5下比对。
文件 模块 状态 重要度
python/sglang/jit_kernel/mla_kv_pack_quantize_fp8.py JIT 内核 added 9.08
python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py 注意力 modified 6.97
python/sglang/jit_kernel/tests/test_mla_kv_pack_quantize_fp8.py JIT 内核 added 7.46
python/sglang/jit_kernel/benchmark/bench_mla_kv_pack_quantize_fp8.py JIT 内核 added 8.92

关键符号

_v0_kernel _v1_flat_kernel _pick_kernel mla_kv_pack_quantize_fp8 _resolve_attn_backend

关键源码片段

python/sglang/jit_kernel/mla_kv_pack_quantize_fp8.py core-logic

新增融合 Triton 内核,包含两种网格变体和混合调度器,是本次 PR 的核心性能优化。

@triton.jit
def _v0_kernel(
    k_nope_ptr, k_pe_ptr, v_ptr,
    k_out_ptr, v_out_ptr,
    k_scale_inv, v_scale_inv,
    s_total,
    k_nope_stride_t, k_nope_stride_h,
    k_pe_stride_t,
    v_stride_t, v_stride_h,
    k_out_stride_t, k_out_stride_h,
    v_out_stride_t, v_out_stride_h,
    QK_NOPE: tl.constexpr, QK_ROPE: tl.constexpr, V_HEAD: tl.constexpr,
    FP8_DTYPE: tl.constexpr, BLOCK_S: tl.constexpr, ENABLE_PDL: tl.constexpr,
):
    # 2D grid: pid_s over token blocks, pid_h over heads
    pid_s = tl.program_id(0)
    pid_h = tl.program_id(1)
    t_idx = pid_s * BLOCK_S + tl.arange(0, BLOCK_S)
    t_mask = t_idx < s_total
    nope_idx = tl.arange(0, QK_NOPE)
    rope_idx = tl.arange(0, QK_ROPE)
    v_idx = tl.arange(0, V_HEAD)
​
    if ENABLE_PDL:
        tl.extra.cuda.gdc_wait()
​
    # load k_nope (s, h, QK_NOPE)
    nope_off = t_idx[:, None] * k_nope_stride_t + pid_h * k_nope_stride_h + nope_idx[None, :]
    k_nope = tl.load(k_nope_ptr + nope_off, mask=t_mask[:, None])
​
    # load k_pe (s, 1, QK_ROPE) broadcast automatically
    pe_off = t_idx[:, None] * k_pe_stride_t + rope_idx[None, :]
    k_pe = tl.load(k_pe_ptr + pe_off, mask=t_mask[:, None])
​
    # load v
    v_off = t_idx[:, None] * v_stride_t + pid_h * v_stride_h + v_idx[None, :]
    v = tl.load(v_ptr + v_off, mask=t_mask[:, None])
​
    # FP8 quantize: promote to FP32, multiply by scale_inv, cast to FP8
    k_nope_fp8 = (k_nope.to(tl.float32) * k_scale_inv).to(FP8_DTYPE)
    k_pe_fp8 = (k_pe.to(tl.float32) * k_scale_inv).to(FP8_DTYPE)
    v_fp8 = (v.to(tl.float32) * v_scale_inv).to(FP8_DTYPE)
​
    # store K: [:, :QK_NOPE] = k_nope, [:, QK_NOPE:] = k_pe
    k_out_base = t_idx[:, None] * k_out_stride_t + pid_h * k_out_stride_h
    tl.store(k_out_ptr + k_out_base + nope_idx[None, :], k_nope_fp8, mask=t_mask[:, None])
    tl.store(k_out_ptr + k_out_base + QK_NOPE + rope_idx[None, :], k_pe_fp8, mask=t_mask[:, None])
​
    # store V
    v_out_off = t_idx[:, None] * v_out_stride_t + pid_h * v_out_stride_h + v_idx[None, :]
    tl.store(v_out_ptr + v_out_off, v_fp8, mask=t_mask[:, None])
​
    if ENABLE_PDL:
        tl.extra.cuda.gdc_launch_dependents()
def _pick_kernel(s: int, num_heads: int) -> Tuple[str, dict]:
    """Tuned on GB300, DSv3 dims, BF16 -> FP8 e4m3."""
    if s <= 2:
        # launch-overhead bound: minimal config avoids warp waste
        return ("v0", {"BLOCK_S": 1, "num_warps": 1, "num_stages": 2, "ENABLE_PDL": True})
    # v0 uses 4 bands for s <= 8, v1_flat uses 3 bands for s > 8
    # (full tuning table omitted for brevity, see source)
    ...
python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py data-contract

通过后端挂钩机制集成融合内核,使用 duck-typing 检测 `pack_prefix_chunk_kv` 方法,无需改动现有注意力后端。

def _resolve_attn_backend(forward_batch: ForwardBatch):
    # TboAttnBackend is a wrapper; unwrap to get the real backend
    backend = forward_batch.attn_backend
    if isinstance(backend, TboAttnBackend):
        backend = backend.primary
    return backend
​
​
def _chunked_prefix_attn_mha(
    self: DeepseekV2AttentionMLA,
    q: torch.Tensor,
    accum_output: torch.Tensor,
    accum_lse: torch.Tensor,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    # try to obtain the optional pack hook from backend
    backend = _resolve_attn_backend(forward_batch)
    pack_fn = getattr(backend, "pack_prefix_chunk_kv", None)
    # kv_b_proj needs BF16 input; if pack_fn exists, fetch latent in BF16
    kv_a_dtype = torch.bfloat16 if pack_fn is not None else q.dtype
​
    assert forward_batch.num_prefix_chunks is not None
    for i in range(forward_batch.num_prefix_chunks):
        forward_batch.set_prefix_chunk_idx(i)
        kv_indices = forward_batch.prefix_chunk_kv_indices[i]
​
        # fetch latent cache in BF16 (or q.dtype if no pack)
        kv_a_normed, k_pe = self._get_mla_kv_buffer(
            kv_indices, kv_a_dtype, forward_batch
        )
        kv = self.kv_b_proj(kv_a_normed)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        v = kv[..., self.qk_nope_head_dim :]
        k_nope = kv[..., : self.qk_nope_head_dim]
​
        if pack_fn is not None:
            # fused cat + FP8 quantize, backend owns the kernel choice
            k, v = pack_fn(k_nope, k_pe, v)
        else:
            # original three-step path (BF16 concatenation + 2x FP8 cast)
            k = torch.empty(
                (k_nope.shape[0], self.num_local_heads,
                 self.qk_nope_head_dim + self.qk_rope_head_dim),
                dtype=v.dtype, device=v.device,
            )
            k[..., :self.qk_nope_head_dim] = k_nope
            k[..., self.qk_nope_head_dim:] = k_pe
​
        output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        # ... accumulate output

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  1. 硬件特定调优_pick_kernel的启发式参数目前基于GB300(SM103)调优,在H100、B200或其他架构上可能非最优。但保守调度(回退到朴素Triton)仍可保证正确性,只是性能可能不达预期。
  2. PDL架构依赖:PDL(Pipeline Deferral)仅在支持gdc_wait/gdc_launch_dependents的架构上启用,通过is_arch_support_pdl()动态检测,不支持的设备自动跳过,无风险。
  3. 核心路径变更_chunked_prefix_attn_mha是MLA prefill的关键路径,修改通过后端挂钩确保默认行为不变。没有后端实现pack_prefix_chunk_kv时与原逻辑完全一致,无回归风险。
  4. 量化精度一致:融合内核使用与原始per-tensor量化一致的to(float32) * scale_inv -> cast(FP8_DTYPE)路径,经130个测试case验证,精度风险低。
  5. cudagraph兼容性:PDL使能时使用tl.extra.cuda API,若cudagraph不支持PDL,调度器会自动禁用PDL。

用户影响:所有使用DeepSeek MLA模型(V2/V3/V4)的chunked-prefill场景将在FP8量化路径下获得即时延迟降低。小batch(1-4 tokens)加速约10倍,中等batch(384 tokens)加速约9倍,大batch(≥2048)加速约5.5倍。无需用户手动配置。
系统影响:减少每个prefill chunk的内核启动次数(从3次降至1次),减少全局内存中间数据写入量(每个chunk约 2 * s * num_heads * (qk_nope+qk_rope+v_head) * sizeof(fp8)字节)。有助于提升GPU利用率和整体吞吐。
团队影响:提供了可复用的JIT内核融合模式和数据契约扩展点。未来类似操作(如其他量化的拼接)可复用相同调度框架。且bench_mla_kv_pack_quantize_fp8.py已注册CI基准,可监控后续修改对性能的影响。

核心路径变更 硬件特定调优 PDL 架构依赖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论