执行摘要
本PR在FlashAttention后端中引入fa_skip_kv_cache标志,在嵌入模式且禁用缓存时跳过KV缓存读写,直接使用原始K/V张量计算注意力,消除每层约19µs的开销,提升推理性能,同时确保不影响其他后端。
功能与动机
动机源于嵌入模式下KV缓存的不必要操作:当配置--chunked-prefill-size -1和--disable-radix-cache时,每个请求仅为单次prefill,KV缓存被写入和读取但从未重用,浪费约19µs每层(store_kvcache ~15µs + prepare_varlen ~4µs)。优化旨在消除此开销,提升嵌入模型推理效率。
实现拆解
实现集中于flashattention_backend.py文件,关键改动点:
- 条件标志添加:在
__init__中添加self.fa_skip_kv_cache,基于server_args.is_embedding、server_args.chunked_prefill_size == -1和server_args.disable_radix_cache判断。
- 缓存写入跳过:在
forward_extend中,修改条件save_kv_cache and not is_cp_mode and not self.fa_skip_kv_cache以跳过set_kv_buffer调用。
- 注意力计算优化:新增代码路径使用
flash_attn_varlen_func代替flash_attn_with_kvcache,直接处理原始K/V张量,并添加断言确保FP8 KV缓存descaling不支持。
评论区精华
Review讨论聚焦于正确性保障:
Qiaolin-Yu: "should we assert k_descale, v_descale, and num_splits are none here? since in previous path, these attributes are passed in"
jasperjiaguo: "Good catch, I passed in num_splits=self.num_splits, so it uses the split heuristics similarly. For fp8 kv it's not relevant for this path w/o kv cache so just added an assert to not silently skip."
结论是添加断言处理FP8支持并传递num_splits参数,确保优化路径安全。
风险与影响
风险:
- 条件判断错误可能导致缓存未填充,影响radix cache查找。
- FP8 KV缓存descaling不支持,限制未来扩展。
- 仅针对FlashAttention后端,需确保逻辑隔离避免兼容性问题。
影响:
- 用户:在特定嵌入模式下性能提升,但仅限于正确配置的场景。
- 系统:减少GPU内核调用,降低计算开销,对整体吞吐量有积极影响。
- 团队:代码维护成本轻微增加,但优化路径清晰。
关联脉络
与历史PR关联显示持续的性能优化趋势:
- PR #22517:优化TRT-LLM attention后端,类似关注计算效率提升。
- PR #22645:添加环境变量控制缓存淘汰间隔,平衡内存与性能开销。
这些PR共同体现了sglang仓库在attention计算和缓存管理上的持续优化方向。
参与讨论