Prhub

#24925 [attn backend] Integrate tokenspeed_mla prefill/decode kernels (fp8 kv cache, blackwell)

原始 PR 作者 Qiaolin-Yu 合并时间 2026-05-14 08:36 文件变更 11 提交数 7 评论 7 代码增减 +462 / -92

执行摘要

集成 tokenspeed_mla Blackwell MLA 内核后端

在 Blackwell GPU (SM100) 上利用 tokenspeed_mla 提供的高效 CuTe DSL 内核来加速 MLA 模型的 prefill 和 decode 阶段,支持 FP8 KV 缓存,以提升推理性能。

建议阅读 tokenspeed_mla_backend.py 了解子类化扩展点设计,学习如何通过重构 trtllm_mla_backend.py 实现内核调度可替换。关注 tokenspeed_mla 包的安装与验证流程。未来可基于此模式集成更多 CuTe DSL 内核。

讨论亮点
  1. 注释迁移 (reviewer: kpham-sgl) 建议将原 forward_decode 中关于 BMM1 缩放因子的注释移至新抽取的 _compute_decode_bmm1_scale 方法上,作者 Qiaolin-Yu 已处理。
  2. 变更风险 (reviewer: Fridge003) 询问修改 trtllm_mla_backend.py 是否有风险,作者回应“主要是包装函数,应该没问题”。

实现拆解

  1. 新增 tokenspeed_mla 后端:创建 python/sglang/srt/layers/attention/tokenspeed_mla_backend.py,包含 _get_tokenspeed_workspace 工具函数和 TokenspeedMLABackend 类(继承 TRTLLMMLABackend),覆盖 _run_decode_kernel_run_prefill_kernel 以调用 tokenspeed_mla 内核。该类构造函数中校验 FP8 数据类型和 page_size,并在初始化时预编译 prefill 内核变体以避免首次请求超时。
  2. 重构 trtllm_mla_backend:在 trtllm_mla_backend.py 中将 forward_decode 中实际内核调用抽取为 _run_decode_kernel_run_prefill_kernel 方法,并提取 _compute_decode_bmm1_scale 用于计算 BMM1 缩放因子,使子类可以轻松替换内核实现而不影响外围逻辑。
  3. 注册新后端:在 attention_registry.py 通过 @register_attention_backend("tokenspeed_mla") 注册工厂函数 create_tokenspeed_mla_backend。在 draft_utils.pycreate_decode_backendcreate_draft_extend_backend 的 backend_map 中添加 "tokenspeed_mla" 键,并实现对应创建方法(返回 TokenspeedMLABackendTokenspeedMLAMultiStepDraftBackend)。
  4. 配置与约束:在 server_args.py_handle_attention_backend_compatibility 中添加新后端的校验:仅支持 Blackwell 架构、强制 page_size 为 32 或 64(自动调整)、要求 kv_cache_dtype 为 fp8_e4m3。同时在 models/deepseek_common/attention_backend_handler.py 注册 handler(委托给 handle_attention_trtllm_mla)并在 forward_mla.py 中添加支持。
  5. 工具与依赖:在 utils/common.py 添加 is_tokenspeed_mla_available 函数用于检测外部包;在 model_runner.pyutils.py 中做微小适配;在 pyproject.toml 添加 tokenspeed_mla 可选依赖。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/tokenspeed_mla_backend.py 注意力后端 added 9.25
python/sglang/srt/layers/attention/trtllm_mla_backend.py MLA 后端 modified 8.58
python/sglang/srt/server_args.py 配置层 modified 6.45
python/sglang/srt/layers/attention/attention_registry.py 注册中心 modified 6.5
python/sglang/srt/speculative/draft_utils.py 推测解码 modified 7.29

关键符号

_get_tokenspeed_workspace TokenspeedMLABackend.__init__ TokenspeedMLABackend._run_decode_kernel TokenspeedMLABackend._run_prefill_kernel TRTLLMMLABackend._compute_decode_bmm1_scale TRTLLMMLABackend._run_decode_kernel TRTLLMMLABackend._run_prefill_kernel create_tokenspeed_mla_backend _create_tokenspeed_mla_decode_backend _create_tokenspeed_mla_prefill_backend is_tokenspeed_mla_available

关键源码片段

python/sglang/srt/layers/attention/tokenspeed_mla_backend.py dependency-wiring

新增的后端实现,核心文件,包含 TokenspeedMLABackend 和 TokenspeedMLAMultiStepDraftBackend 类,以及 workspace 管理与内核调用。

# tokenspeed_mla_backend.py · 新增于 PR#24925
# 为 Blackwell GPU 提供 tokenspeed-mla CuTe DSL 注意力后端 (FP8 KV cache)from __future__ import annotations# 子类化 TRTLLMMLABackend 并仅覆盖 _run_decode_kernel / _run_prefill_kernel
# 所有元数据、KV 缓存布局、CUDA 图流水线、FP8 量化 /rope、
# draft-extend 填充和 chunked-prefix 调度均从父类继承import torch
from sglang.srt.layers.attention.trtllm_mla_backend import (
    TRTLLMMLABackend,
    TRTLLMMLAMultiStepDraftBackend,
    _quantize_fp8_qkv,
)
from sglang.srt.utils import is_tokenspeed_mla_availableif is_tokenspeed_mla_available():
    import tokenspeed_mla# 全局 workspace 缓存,按设备存储
_g_tokenspeed_workspace: dict[torch.device, torch.Tensor] = {}# 最大 q_len 为 8 以覆盖 EAGLE3 的 4 个 draft token 并留余量
_TOKENSPEED_MAX_Q_LEN = 8
​
​
def _get_tokenspeed_workspace(
    device: torch.device, num_heads: int, kv_lora_rank: int
) -> torch.Tensor:
    """获取或分配 tokenspeed_mla_decode 所需的 workspace。"""
    # 计算需求 : num_sms * num_heads * max_q_len * (kv_lora_rank + 1) * 4 (float32)
    needed = (
        tokenspeed_mla.get_num_sm(device)
        * num_heads
        * _TOKENSPEED_MAX_Q_LEN
        * (kv_lora_rank + 1)
        * 4
    )
    existing = _g_tokenspeed_workspace.get(device)
    if existing is None or existing.numel() < needed:
        _g_tokenspeed_workspace[device] = torch.empty(
            needed, dtype=torch.int8, device=device
        )
    return _g_tokenspeed_workspace[device]
​
​
class TokenspeedMLABackend(TRTLLMMLABackend):
    """tokenspeed-mla 后端 (Blackwell SM100, FP8 KV cache)。"""
​
    def __init__(
        self,
        model_runner,
        skip_prefill: bool = False,
        kv_indptr_buf=None,
        q_indptr_decode_buf=None,
    ):
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf
        )
        # 强制 FP8 数据类型
        if self.data_type != torch.float8_e4m3fn:
            raise ValueError(
                "tokenspeed_mla backend requires --kv-cache-dtype fp8_e4m3, "
                f"got data_type={self.data_type}."
            )
        if self.page_size not in (32, 64):
            raise ValueError(
                "tokenspeed_mla backend requires page_size in {32, 64}, "
                f"got page_size={self.page_size}."
            )
        self._tokenspeed_workspace: Optional[torch.Tensor] = None
​
        # 预编译 prefill 内核变体,避免首次请求触发调度器看门狗超时
        if is_tokenspeed_mla_available():
            _compile_prefill_kernel = tokenspeed_mla.mla_prefill._compile_prefill_kernel
            _compiled_kernels = tokenspeed_mla.mla_prefill._compiled_kernels
            # ... 遍历 causal 和 return_lse 组合并编译
python/sglang/srt/layers/attention/trtllm_mla_backend.py core-logic

重构的基础后端,提取 _run_decode_kernel、_run_prefill_kernel 和 _compute_decode_bmm1_scale 方法,作为子类可替换的扩展点。

# trtllm_mla_backend.py · 变更于 PR#24925
# 提取 BMM1 缩放因子计算和内核调度钩子def _compute_decode_bmm1_scale(self, layer: RadixAttention) -> float:
    """计算 BMM1 的最终缩放因子: q_scale * k_scale * softmax_scale。    k_scale 仅在 KV cache 为 FP8 时取自 checkpoint,
    否则为 1.0;若 checkpoint 含 k_scale 但 dtype 非 FP8 则发出告警。
    """
    q_scale = 1.0
    if self.data_type == torch.float8_e4m3fn:
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )
    else:
        if getattr(layer, "k_scale_float", None) is not None:
            logger.warning_once(
                "Checkpoint has k_scale but KV cache dtype is not FP8. "
                "Ignoring k_scale for BMM1 (k_scale=%.4f, kv_dtype=%s).",
                layer.k_scale_float,
                self.data_type,
            )
        k_scale = 1.0
    return q_scale * k_scale * layer.scaling
​
​
def _run_decode_kernel(
    self,
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    max_seq_len: int,
    layer: RadixAttention,
) -> torch.Tensor:
    """Decode 内核调用钩子,子类可覆盖以替换具体实现。"""
    bmm1_scale = self._compute_decode_bmm1_scale(layer)
    seq_lens_i32 = (
        seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32)
    )
    return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=self.workspace_buffer,
        qk_nope_head_dim=self.qk_nope_head_dim,
        kv_lora_rank=self.kv_lora_rank,
        qk_rope_head_dim=self.qk_rope_head_dim,
        block_tables=block_tables,
        seq_lens=seq_lens_i32,
        max_seq_len=max_seq_len,
        bmm1_scale=bmm1_scale,
        skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(),
    )

评论区精华

注释位置调整 documentation

kpham-sgl 建议将原 forward_decode 中的 BMM1 缩放注释移至新抽取的 _compute_decode_bmm1_scale 方法上。

结论:作者 Qiaolin-Yu 回复 'done',确认已迁移注释。 · 已解决

重构风险确认 设计

Fridge003 询问修改 trtllm_mla_backend.py 是否有风险。

结论:作者 Qiaolin-Yu 表示 'mostly just wrap some functions. so should be fine',认定无风险。 · 已解决

风险与影响

  • 外部依赖风险:需安装 tokenspeed_mla 包,可能面临版本兼容或 API 变动问题。
  • 硬件锁定:仅限 Blackwell GPU (SM100),其他 NVIDIA GPU 不可用。
  • 配置约束严格:强制 FP8 KV cache 且 page_size 固定为 32 或 64,使用不匹配时会报错或自动调整,可能引发用户困惑。
  • 重构影响trtllm_mla_backend.py 的重构理论上不影响已有 trtllm_mla 后端行为,但提取方法时可能引入回归(如缩放因子计算、参数传递错误),需依赖现有测试覆盖。
  • 缺失测试:本次未新增测试文件,新后端的正确性和性能验证依赖外部 benchmark。
  • 用户影响:仅影响在 Blackwell GPU 上使用 MLA 模型并选择 tokenspeed_mla 后端的用户,其他用户无影响。
  • 系统影响:新增可选外部依赖,运行时需 CUDA 12.8+ 和 Blackwell 架构。自动调整 page_size 的行为可能覆盖用户显式设置。
  • 团队影响:需维护与 tokenspeed_mla 的集成,关注上游内核更新;重构后的 trtllm_mla 后端更易扩展,有利于未来集成其他内核。
依赖外部 tokenspeed_mla 包 Blackwell 硬件限定 FP8 KV cache 约束 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论