Prhub

#25299 [NSA] Avoid repeated NSA MQA logits memory queries

原始 PR 作者 YAMY1234 合并时间 2026-05-20 07:04 文件变更 1 提交数 2 评论 22 代码增减 +61 / -19

执行摘要

缓存 MQA logits 内存预算,避免重复 host sync

NSA indexer 在大的 MQA logits 路径上会频繁调用 torch.cuda.mem_get_info,这是一个 host-syncing 查询,位于延迟敏感的 prefill 路径中,影响性能。(Reported-by: @samuellees)

本 PR 展示了如何通过缓存避免 host-device 同步来优化延迟敏感路径,设计简洁效果明显。建议阅读以学习性能优化技巧,尤其对于涉及 GPU 内存查询的场景。同时,关注缓存预算的计算方式,可参考其双预算约束设计。

讨论亮点
  • samuellees 要求将魔数 0.30.5 提取为类常量,作者已采纳并添加 _MQA_LOGITS_FREE_MEM_FRACTION_MQA_LOGITS_TOTAL_MEM_FRACTION
  • Fridge003 关注缓存预算在首次大 batch 时可能过松,导致后续 chunk 不足而 OOM。作者解释:缓存仅在非 CUDA graph capture 路径上由大窗口触发,且结果会被工作量无关的静态服务内存头寸(mem_fraction_static)限制,不会过松。该问题已解决。

实现拆解

  1. Indexer 类中添加类级常量,将原先硬编码的 0.50.38_000_000 等提取为 _MQA_LOGITS_FREE_MEM_FRACTION_MQA_LOGITS_TOTAL_MEM_FRACTION_MQA_LOGITS_STATIC_SKIP_ELEMS
  2. 新增 _get_mqa_logits_budget_bytes(self, device_index) 方法:首先检查 _mqa_logits_budget_bytes 字典缓存;若未缓存,则根据设备总内存和 mem_fraction_static 计算静态服务内存预算,并在 CUDA graph capture 模式返回静态预算而不缓存;否则,执行一次 torch.cuda.mem_get_info 获取当前空闲内存,取两者较小值作为预算,并缓存到字典中。
  3. 重构 _should_chunk_mqa_logits 方法:将参数 device 改为 device_index;使用缓存的预算替代每次调用 mem_get_info;将原始判断条件 logits_bytes * 2 > free_mem or logits_bytes > total_mem * 0.3 简化为 logits_bytes > logits_budget_bytes
  4. _get_topk_ragged 中复用 cu_seqlens_q 的分配,避免在 chunked paged 路径的循环内部重复分配。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/nsa/nsa_indexer.py NSA 模块 modified 7.29

关键符号

_get_mqa_logits_budget_bytes _should_chunk_mqa_logits

关键源码片段

python/sglang/srt/layers/attention/nsa/nsa_indexer.py core-logic

所有变更集中于此文件:类常量、缓存方法和 chunk 逻辑重构。

class Indexer(MultiPlatformOp):
    # 类常量,替代魔数
    _MQA_LOGITS_BYTES_PER_ELEM = 4
    _MQA_LOGITS_STATIC_SKIP_ELEMS = 8_000_000 # 跳过小 batch 的静态阈值
    _MQA_LOGITS_FREE_MEM_FRACTION = 0.5 # 空闲内存使用上限比例
    _MQA_LOGITS_TOTAL_MEM_FRACTION = 0.3 # 总内存使用上限比例
    _mqa_logits_budget_bytes: Dict[int, int] = {} # 每设备缓存字典
​
    def _get_mqa_logits_budget_bytes(self, device_index: int) -> int:
        """获取缓存的内存预算,避免每次 host sync"""
        # 优先返回缓存值
        cached_budget = self._mqa_logits_budget_bytes.get(device_index)
        if cached_budget is not None:
            return cached_budget
​
        total_mem = torch.cuda.get_device_properties(device_index).total_memory
        # 基于总内存的预算
        total_mem_budget = int(total_mem * self._MQA_LOGITS_TOTAL_MEM_FRACTION)
        mem_fraction_static = get_global_server_args().mem_fraction_static
        if mem_fraction_static is None:
            static_budget = total_mem_budget
        else:
            static_free_mem = int(total_mem * max(0.0, 1.0 - mem_fraction_static))
            # 静态预算取空闲内存比例与总内存比例的较小值,避免过度松弛
            static_budget = min(
                int(static_free_mem * self._MQA_LOGITS_FREE_MEM_FRACTION),
                total_mem_budget,
            )
        static_budget = max(1, static_budget)
​
        # CUDA graph capture 阶段不缓存,返回静态预算
        if get_is_capture_mode():
            return static_budget
​
        # 首次非 capture 时查询一次并缓存
        free_mem, _ = torch.cuda.mem_get_info(device_index)
        budget_bytes = min(
            int(free_mem * self._MQA_LOGITS_FREE_MEM_FRACTION), static_budget
        )
        budget_bytes = max(1, budget_bytes)
        self._mqa_logits_budget_bytes[device_index] = budget_bytes
        return budget_bytes
​
    def _should_chunk_mqa_logits(
        self, num_q: int, num_k: int, device_index: int
    ) -> Tuple[bool, int]:
        """判断是否需要分块计算 MQA logits 以避免 OOM
        Returns: (need_chunk, logits_budget_bytes)
        """
        # 小 batch 直接跳过,避免不必要的开销
        if num_q * num_k < self._MQA_LOGITS_STATIC_SKIP_ELEMS:
            return False, 0
​
        logits_bytes = num_q * num_k * self._MQA_LOGITS_BYTES_PER_ELEM
        logits_budget_bytes = self._get_mqa_logits_budget_bytes(device_index)
        need_chunk = logits_bytes > logits_budget_bytes
        return need_chunk, logits_budget_bytes

评论区精华

将魔数提取为类常量 style

samuellees 要求将 0.3 和 0.5 命名为常量。

结论:作者添加了对应常量并回复 'Adjusted'。 · 已解决

缓存预算的过松风险 设计

Fridge003 担心首次大 batch 时缓存的值可能过松,导致后续 chunk 不足而 OOM。

结论:作者解释缓存仅在非 capture 路径上计算,且受 static budget cap,不会过松。 · 已解决

风险与影响

  1. 缓存过时风险:缓存的值可能在长期运行中因内存碎片或并发请求而变得不准确,但通过 mem_fraction_static 和总内存比例双重限制,OOM 风险较低。
  2. 首次查询仍有一次 sync:第一个大 batch 非 capture 路径仍有一次 mem_get_info 开销,但仅一次,可接受。
  3. 测试覆盖不足:未添加直接单元测试,但已有 E2E 测试覆盖 NSA 路径,且性能测试验证了正确性。
  4. 多线程安全_mqa_logits_budget_bytes 为类字典,device_index 唯一且写后读,无竞争条件。
  • 用户:DeepSeek 模型在长上下文预填充阶段获得 6-10% 延迟改善和吞吐提升,尤其大 batch 场景受益显著。
  • 系统:减少 GPU 同步次数,提高调度效率和 GPU 利用率。
  • 团队:代码更易维护,魔数提取为常量便于未来调参。
缓存依赖 GPU 属性 缺少测试覆盖 核心路径变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论