Prhub

#24138 [SWA] Ensure we use pre-computed SWA cache location during prefill

原始 PR 作者 merrymercy 合并时间 2026-05-01 15:01 文件变更 1 提交数 2 评论 3 代码增减 +8 / -4

执行摘要

修复 prefill 阶段 SWA cache location 被忽略问题

在 sliding-window attention 模型中,prefill 阶段已经通过 SWAKVPool 预计算了 SWA cache location 并存储在 forward_batch.out_cache_loc_swa 中,但 _get_layer_cache_loc 方法始终调用 translate_loc_from_full_to_swa 从完整 cache location 重新转换,忽略了预计算的值。这会导致在不支持重新转换的 CUDA graph 等场景下出现错误。PR 从 meta-llama/prod_inference 上游同步修复。

建议精读。此 PR 展示了一个典型的“使用预计算值替代重复计算”的优化模式,同时也体现了 review 中发现的“直接引用状态属性 vs 通过 forward_batch 传递”的设计陷阱。对于维护 SWA 或类似缓存机制的同学,该变更和讨论值得学习。

讨论亮点

Review 中 gemini-code-assist[bot] 指出,如果直接使用 self._swa_kv_pool.swa_loc 作为缓存位置,会引入跨批次使用过期数据(stale data)的风险。具体来说,swa_locSWAKVPool 的状态属性且是批次相关的,在 ModelRunner._forward_raw 中仅当当前批次提供新值时才会更新,从未被清除。若后续批次未提供新的 swa_loc,该方法将返回上一批次的 swa_loc,可能导致错误的 KV cache 写入或形状不匹配。最终提交的代码已采纳该建议,改用 forward_batch.out_cache_loc_swa(从 ForwardBatch 中获取)代替 self._swa_kv_pool.swa_loc,避免了状态残留问题。

实现拆解

  1. 修改 _get_layer_cache_loc 方法:将参数从 cache_loc: torch.Tensor 改为 forward_batch: ForwardBatch,以便访问 forward_batch.out_cache_loc_swaforward_batch.out_cache_loc
  2. 优先返回预计算的 SWA cache location:当 forward_batch.out_cache_loc_swa 不为 None 时,直接返回该值,避免调用 translate_loc_from_full_to_swa 进行冗余转换。
  3. 更新调用点 _fused_fp8_set_kv_buffer:将 self._get_layer_cache_loc(layer, forward_batch.out_cache_loc) 调整为 self._get_layer_cache_loc(layer, forward_batch),以匹配新的函数签名。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/trtllm_mha_backend.py 注意力层 modified 5.97

关键符号

_get_layer_cache_loc _fused_fp8_set_kv_buffer

关键源码片段

python/sglang/srt/layers/attention/trtllm_mha_backend.py core-logic

唯一修改的文件,核心变更在 `_get_layer_cache_loc` 方法及其调用点 `_fused_fp8_set_kv_buffer`,修复 SWA prefill 阶段缓存位置被忽略的 bug。

def _get_layer_cache_loc(
    self,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """Return cache locations in the correct index space for the given layer.    如果该层是 SWA 层,优先返回预计算的 SWA cache location(`out_cache_loc_swa`),
    避免从完整 cache location 重新转换,后者在 CUDA graph 等场景下可能不准确。
    """
    if self.use_sliding_window_kv_pool:
        _, is_swa = self._swa_kv_pool.layers_mapping[layer.layer_id]
        if is_swa:
            # 如果 forward_batch 已经预计算了 SWA cache location,直接返回
            if forward_batch.out_cache_loc_swa is not None:
                return forward_batch.out_cache_loc_swa
            # 否则从完整 cache location 实时转换(作为 fallback)
            return self._swa_kv_pool.translate_loc_from_full_to_swa(
                forward_batch.out_cache_loc
            )
    # 非 SWA 层,直接返回完整 cache location
    return forward_batch.out_cache_loc

评论区精华

使用 forward_batch 传递预计算位置 vs 直接引用 SWAKVPool 状态属性 设计

gemini-code-assist[bot] 指出,若直接使用 `self._swa_kv_pool.swa_loc` 会引入跨批次使用过期数据(stale data)的风险,因为 `swa_loc` 是状态属性且批次相关,但在 `ModelRunner._forward_raw` 中从未被清除。

结论:最终代码采用 `forward_batch.out_cache_loc_swa`(由 ForwardBatch 传递),而不是直接引用 `SWAKVPool` 的状态属性,避免了状态残留问题。 · 已解决

风险与影响

风险极低。变更仅涉及一个私有方法 _get_layer_cache_loc 及其调用点,改动量小。主要风险在于:若 forward_batch.out_cache_loc_swaout_cache_loc 的语义不一致(例如在非 SWA 层或 decode 阶段),但现有逻辑已通过 is_swa 判断和 None 检查加以保护,不会误用。此外,该路径仅在 use_sliding_window_kv_pool 为 True 时生效,不影响非 SWA 模型。

影响范围限于使用 TRTLLM 注意力后端且启用 sliding-window KV pool 的模型(如某些 Blackwell 平台上的模型)。对用户透明,但可修复在 CUDA graph 等场景下 SWA 模型 prefill 阶段未使用预计算缓存位置导致的正确性问题。无性能回退,因为免去了多余的坐标转换调用反而可能带来轻微性能提升。

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论