执行摘要
- 一句话:缓存 MQA logits 内存预算,避免重复 host sync
- 推荐动作:本 PR 展示了如何通过缓存避免 host-device 同步来优化延迟敏感路径,设计简洁效果明显。建议阅读以学习性能优化技巧,尤其对于涉及 GPU 内存查询的场景。同时,关注缓存预算的计算方式,可参考其双预算约束设计。
功能与动机
NSA indexer 在大的 MQA logits 路径上会频繁调用 torch.cuda.mem_get_info,这是一个 host-syncing 查询,位于延迟敏感的 prefill 路径中,影响性能。(Reported-by: @samuellees)
实现拆解
- 在
Indexer 类中添加类级常量,将原先硬编码的 0.5、0.3、8_000_000 等提取为 _MQA_LOGITS_FREE_MEM_FRACTION、_MQA_LOGITS_TOTAL_MEM_FRACTION、_MQA_LOGITS_STATIC_SKIP_ELEMS。
- 新增
_get_mqa_logits_budget_bytes(self, device_index) 方法:首先检查 _mqa_logits_budget_bytes 字典缓存;若未缓存,则根据设备总内存和 mem_fraction_static 计算静态服务内存预算,并在 CUDA graph capture 模式返回静态预算而不缓存;否则,执行一次 torch.cuda.mem_get_info 获取当前空闲内存,取两者较小值作为预算,并缓存到字典中。
- 重构
_should_chunk_mqa_logits 方法:将参数 device 改为 device_index;使用缓存的预算替代每次调用 mem_get_info;将原始判断条件 logits_bytes * 2 > free_mem or logits_bytes > total_mem * 0.3 简化为 logits_bytes > logits_budget_bytes。
- 在
_get_topk_ragged 中复用 cu_seqlens_q 的分配,避免在 chunked paged 路径的循环内部重复分配。
关键文件:
python/sglang/srt/layers/attention/nsa/nsa_indexer.py(模块 NSA模块;类别 source;类型 core-logic;符号 _get_mqa_logits_budget_bytes, _should_chunk_mqa_logits): 所有变更集中于此文件:类常量、缓存方法和 chunk 逻辑重构。
关键符号:_get_mqa_logits_budget_bytes, _should_chunk_mqa_logits
关键源码片段
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
所有变更集中于此文件:类常量、缓存方法和 chunk 逻辑重构。
class Indexer(MultiPlatformOp):
# 类常量,替代魔数
_MQA_LOGITS_BYTES_PER_ELEM = 4
_MQA_LOGITS_STATIC_SKIP_ELEMS = 8_000_000 # 跳过小 batch 的静态阈值
_MQA_LOGITS_FREE_MEM_FRACTION = 0.5 # 空闲内存使用上限比例
_MQA_LOGITS_TOTAL_MEM_FRACTION = 0.3 # 总内存使用上限比例
_mqa_logits_budget_bytes: Dict[int, int] = {} # 每设备缓存字典
def _get_mqa_logits_budget_bytes(self, device_index: int) -> int:
"""获取缓存的内存预算,避免每次 host sync"""
# 优先返回缓存值
cached_budget = self._mqa_logits_budget_bytes.get(device_index)
if cached_budget is not None:
return cached_budget
total_mem = torch.cuda.get_device_properties(device_index).total_memory
# 基于总内存的预算
total_mem_budget = int(total_mem * self._MQA_LOGITS_TOTAL_MEM_FRACTION)
mem_fraction_static = get_global_server_args().mem_fraction_static
if mem_fraction_static is None:
static_budget = total_mem_budget
else:
static_free_mem = int(total_mem * max(0.0, 1.0 - mem_fraction_static))
# 静态预算取空闲内存比例与总内存比例的较小值,避免过度松弛
static_budget = min(
int(static_free_mem * self._MQA_LOGITS_FREE_MEM_FRACTION),
total_mem_budget,
)
static_budget = max(1, static_budget)
# CUDA graph capture 阶段不缓存,返回静态预算
if get_is_capture_mode():
return static_budget
# 首次非 capture 时查询一次并缓存
free_mem, _ = torch.cuda.mem_get_info(device_index)
budget_bytes = min(
int(free_mem * self._MQA_LOGITS_FREE_MEM_FRACTION), static_budget
)
budget_bytes = max(1, budget_bytes)
self._mqa_logits_budget_bytes[device_index] = budget_bytes
return budget_bytes
def _should_chunk_mqa_logits(
self, num_q: int, num_k: int, device_index: int
) -> Tuple[bool, int]:
"""判断是否需要分块计算 MQA logits 以避免 OOM
Returns: (need_chunk, logits_budget_bytes)
"""
# 小 batch 直接跳过,避免不必要的开销
if num_q * num_k < self._MQA_LOGITS_STATIC_SKIP_ELEMS:
return False, 0
logits_bytes = num_q * num_k * self._MQA_LOGITS_BYTES_PER_ELEM
logits_budget_bytes = self._get_mqa_logits_budget_bytes(device_index)
need_chunk = logits_bytes > logits_budget_bytes
return need_chunk, logits_budget_bytes
评论区精华
风险与影响
- 风险:
- 缓存过时风险:缓存的值可能在长期运行中因内存碎片或并发请求而变得不准确,但通过
mem_fraction_static 和总内存比例双重限制,OOM 风险较低。
- 首次查询仍有一次 sync:第一个大 batch 非 capture 路径仍有一次
mem_get_info 开销,但仅一次,可接受。
- 测试覆盖不足:未添加直接单元测试,但已有 E2E 测试覆盖 NSA 路径,且性能测试验证了正确性。
- 多线程安全:
_mqa_logits_budget_bytes 为类字典,device_index 唯一且写后读,无竞争条件。
- 影响:
- 用户:DeepSeek 模型在长上下文预填充阶段获得 6-10% 延迟改善和吞吐提升,尤其大 batch 场景受益显著。
- 系统:减少 GPU 同步次数,提高调度效率和 GPU 利用率。
- 团队:代码更易维护,魔数提取为常量便于未来调参。
- 风险标记:缓存依赖GPU属性, 缺少测试覆盖, 核心路径变更
关联脉络
参与讨论