Prhub

#40941 [Attention][TurboQuant] Share dequant buffers, eliminate float16_copy

原始 PR 作者 bhoomit 合并时间 2026-04-27 13:48 文件变更 4 提交数 4 评论 1 代码增减 +98 / -92

执行摘要

共享 TQ 去量化缓冲区并消除 float16_copy

使用 TurboQuant KV 缓存时,每个注意力层独立分配去量化缓冲区导致内存随层数和上下文长度线性增长,在 1M 上下文下需要额外 57GB 内存,严重限制了可服务的最长上下文。此外,decode 阶段的 stage2 归约输出 fp32 后再转换为 fp16 造成了冗余的 float16_copy 内核启动。需要优化以支持更长上下文并提高性能。

值得精读,展示了注意力后端中从逐层缓冲区到全局 WorkSpace 管理的演进,以及 Triton 核函数级别的优化技巧。对于关注推理引擎性能优化的工程师有参考价值。

讨论亮点

唯一的 review 来自 gemini-code-assist[bot],建议在 _continuation_prefill 中也使用 current_workspace_manager().get_simultaneous() 来分配 k_fullv_full,以避免每次请求分配大张量,从而进一步降低长上下文的内存峰值。作者未直接回应,但 PR 仍被批准合并,可能留待后续优化。

实现拆解

  1. 缓冲区管理迁移至 WorkspaceManager:在 turboquant_attn.py 中,将原先每个注意力层独立注册的 _tq_mid_o_buf_tq_output_buf_tq_lse_buf 以及去量化缓冲区(_tq_k_dequant_buf_tq_v_dequant_buf)改为通过 current_workspace_manager().get_simultaneous() 分配,实现跨层共享。同时移除了 attention.py 中的 _init_turboquant_buffers 方法。

  2. 消除 float16_copy 内核:在 triton_decode_attention.py_fwd_kernel_stage2 中添加 OUTPUT_FP16 编译常量,当输出 dtype 为 fp16 时直接在核函数内部进行类型转换,省去了外部的 .to(query.dtype) 调用。对应的调用处 triton_turboquant_decode_attention 也做了适配,要求 output_buf 的 dtype 与 query 一致。

  3. 缓存辅助张量:在 _prefill_attention 中缓存 cu_seqlens 张量(self._cu_2)和 arange 张量(self._arange_cache),避免每次调用时重新创建,减少 host→device 的传输和小型核函数启动。

  4. fp16 Hadamard 旋转:在 _ensure_on_device 中预先计算 layer._tq_Pi_half = H.to(torch.float16),并在 _continuation_prefill 中直接使用 fp16 矩阵乘法进行旋转,替代原先的 fp32 乘法,降低带宽消耗。

  5. 预分配连续缓冲区:在 _continuation_prefill 中预分配 k_fullv_full(seq_len, Hk, D),直接将缓存和新数据写入切片,消除了 torch.cat.contiguous() 的临时分配。

文件 模块 状态 重要度
vllm/model_executor/layers/attention/attention.py 注意力层 modified 7.54
vllm/v1/attention/backends/turboquant_attn.py 注意力后端 modified 7.37
vllm/v1/attention/ops/triton_turboquant_decode.py Triton 内核 modified 4.09
vllm/v1/attention/ops/triton_decode_attention.py Triton 内核 modified 3.69

关键符号

_init_turboquant_buffers (deleted) _ensure_on_device _prefill_attention _continuation_prefill triton_turboquant_decode_attention _fwd_kernel_stage2

关键源码片段

vllm/v1/attention/backends/turboquant_attn.py dependency-wiring

核心修改文件,全部优化逻辑的主体:引入 WorkspaceManager、缓存辅助张量、fp16 Hadamard、预分配连续缓冲区等。

# _ensure_on_device 中的缓冲区初始化片段(head 版本):
def _ensure_on_device(self, layer, device):
    if not hasattr(layer, "_tq_cached"):
        D = self.head_size
        H = _build_hadamard(D, str(device))
        layer._tq_PiT = H
        layer._tq_Pi = H
        # 新增:fp16 副本用于续写 prefill 中的旋转,减少带宽
        layer._tq_Pi_half = H.to(torch.float16)
​
        # 通过 WorkspaceManager 分配共享缓冲区(全局每 ubatch 一组)
        # 替代原先每个层独立 register_buffer
        # 形状:(B, Hq, S, D+1) mid_o / (B, Hq, D) output / (B, Hq) lse
        # 实际分配由 get_simultaneous 统一管理
​
        # Centroids 也延迟到此处初始化
        layer._tq_centroids = get_centroids(
            D, self.tq_config.centroid_bits
        ).to(device=device, dtype=torch.float32)
        c_sorted, _ = layer._tq_centroids.sort()
        layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
        layer._tq_cached = True
# _prefill_attention 中的缓存 arange 和 cu_seqlens 优化(head 版本):
if not hasattr(self, "_cu_2"):
    self._cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32)
_max_seq = attn_metadata.max_seq_len
_ac: torch.Tensor | None = getattr(self, "_arange_cache", None)
if _ac is None or _ac.shape[0] <= _max_seq:
    _ac = torch.arange(
        0, _max_seq + 1, device=query.device,
        dtype=attn_metadata.seq_lens.dtype
    )
    self._arange_cache = _ac
_arange_cache: torch.Tensor = _ac# 后续构造 synth_seq_lens 时直接切片,避免 torch.arange 启动
synth_seq_lens = _arange_cache[cached_len + 1 : seq_len + 1]

评论区精华

在 continuation prefill 中使用 WorkspaceManager 以减少分配 性能

gemini-code-assist[bot] 建议在 `_continuation_prefill` 中也使用 `current_workspace_manager().get_simultaneous()` 分配 `k_full` 和 `v_full`,以避免每次请求分配大张量,进一步降低长上下文内存峰值。

结论:作者未回复,PR 已合并,该优化可能留待后续 PR。 · 待处理

风险与影响

风险较低。主要风险:1)WorkspaceManager 的依赖:如果 workspace 未正确初始化或版本不兼容,可能导致崩溃,但 vLLM v1 默认启用该组件。2)CUDA Graph 捕获:此前因缓冲区地址变化无法捕获,本 PR 通过预分配修复了此问题,但若其他动态分配代码未正确处理仍可能失败。3)无新增测试覆盖:仅依赖现有 117 个 TQ 单元测试,未对共享缓冲区逻辑添加专项测试,存在回归隐患。

直接影响使用 TurboQuant KV 缓存 dtype 的用户,在长上下文推理时可节省数十 GB 显存并提升 prefill 速度(如 15K 上下文 TTFT 降低 14%)。decode 吞吐量基本不变。非 TQ 后端完全不受影响。该 PR 也是后续在 TQ 后端支持 CUDA Graph 的前提。

WorkspaceManager 依赖 CUDA Graphs 兼容性 缺少新增测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论