Prhub

#26513 Fix FlashInfer SWA EXTEND-with-prefix correctness in merge_state path

原始 PR 作者 ch-wan 合并时间 2026-05-28 16:16 文件变更 1 提交数 1 评论 3 代码增减 +75 / -8

执行摘要

修复 FlashInfer SWA EXTEND-with-prefix 的正确性

PR body 指出:FlashInfer SWA EXTEND-with-prefix 且 use_ragged=True 时 (merge_state 分支) 返回错误输出,在模块级测试中与 HF PyTorch 参考的 max abs diff 约 0.2,并在服务级 SWA 负载中产生静默错误 token。

建议精读。该 PR 精准定位并修复了一个在 SWA + 前缀缓存组合场景下的静默错误,涉及 FlashInfer 滑动窗口注意力的内部协作细节,对理解 FlashInfer 后端的多 wrapper merge_state 路径有较高参考价值。

讨论亮点

Review 机器人提出两条建议:(1) 自定义 mask 应保持 torch.bool 而非 uint8,因为 custom_mask API 预期 unpacked boolean mask;(2) 应限制 paged 前缀长度到滑动窗口大小,防止在长前缀场景下分配过大 mask。

实现拆解

  1. forward_extend 传入 window_left:在 use_ragged=True 分支中计算 swa_window_left,与 use_ragged=False 分支保持一致,并传入两个 forward_return_lse 调用,确保每个 wrapper 都应用滑动窗口遮罩。
  2. update_sliding_window 分支处理 use_ragged=True:对于 use_ragged=True,将 paged wrapper 看到的 K/V 限制为前缀中最后 window 个 token,避免读取未初始化的扩展 token 缓存位置。
  3. 新增 _build_swa_prefix_custom_mask 辅助函数:为每个 extend query 构建基于全局位置的自定义遮罩 mask,确保 Q 只关注 cache 前缀中落在窗口内的 key。
  4. 仅修改 flashinfer_backend.py:所有变更集中在单个文件,无测试或配置配套。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/flashinfer_backend.py 注意力层 modified 7.59

关键符号

forward_extend update_sliding_window _build_swa_prefix_custom_mask

关键源码片段

python/sglang/srt/layers/attention/flashinfer_backend.py core-logic

所有变更集中于此,修复了 SWA EXTEND-with-prefix 的 merge_state 路径

以下是关键修改的代码片段,集中在 flashinfer_backend.py 中,展示了 forward_extend 中传入 window_left 以及 update_sliding_window 中对 use_ragged=True 的分支处理。

# forward_extend 方法中,两个 forward_return_lse 调用都传入 window_left
swa_window_left = (
    layer.sliding_window_size
    if not (
        self.forward_metadata.multi_item_params
        and self.forward_metadata.multi_item_params.is_enabled()
    )
    else -1
)
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
    q.view(-1, layer.tp_q_head_num, layer.head_dim),
    k.view(-1, layer.tp_k_head_num, layer.head_dim),
    v.view(-1, layer.tp_v_head_num, layer.head_dim),
    causal=causal,
    sm_scale=layer.scaling,
    window_left=swa_window_left, # 之前缺失,导致无滑动窗口遮罩
    logits_soft_cap=logits_soft_cap,
)
o2, s2 = prefill_wrapper_paged.forward_return_lse(
    q.view(-1, layer.tp_q_head_num, layer.head_dim),
    self.token_to_kv_pool.get_kv_buffer(layer.layer_id),
    causal=False,
    sm_scale=layer.scaling,
    window_left=swa_window_left, # 之前缺失
    logits_soft_cap=logits_soft_cap,
)# update_sliding_window 中对 use_ragged=True 的分支处理
for wrapper_id in range(2):
    swa_paged_custom_mask = None
    if wrapper_id == 0:
        if use_ragged:
            # 将 paged 部分限制为前缀中最后的 window 个 token
            effective_start = torch.clamp(
                prefix_lens - self.sliding_window_size, min=0
            )
            paged_kernel_lens = prefix_lens - effective_start
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
            kv_start_idx = effective_start
            swa_paged_custom_mask = self._build_swa_prefix_custom_mask(
                prefix_lens, seq_lens, effective_start
            )
        else:
            # 原有逻辑:窗口注意力使用 paged
            paged_kernel_lens = torch.minimum(
                seq_lens,
                torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
            )
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
            kv_start_idx = seq_lens - paged_kernel_lens
    else:
        paged_kernel_lens = seq_lens
        paged_kernel_lens_sum = seq_lens_sum
        kv_start_idx = seq_lens - paged_kernel_lens# 新增辅助函数 _build_swa_prefix_custom_mask
def _build_swa_prefix_custom_mask(
    self,
    prefix_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    kv_start_idx: torch.Tensor,
) -> Optional[torch.Tensor]:
    """构造用于 paged wrapper 的自定义 SWA mask。    Paged KV 覆盖全局位置 [kv_start_idx[i], prefix_lens[i]),
    此 mask 为每个 extend query 限制其只关注窗口内的前缀 key。
    当所有 key 都在窗口内时返回 None。
    """
    window = self.sliding_window_size
    if window is None or window <= 0:
        return None
    # 计算每个 query 的允许 key 范围,构建 boolean mask
    # ( 简化示意,实际实现按 batch 向量化 )
    ...
    return mask # torch.bool tensor

评论区精华

自定义 mask 数据类型应为 torch.bool 而非 uint8 正确性

Review 机器人指出,custom_mask API 预期 unpacked boolean mask,若使用 uint8 会被误判为 packed mask 路径,导致遮罩失效。

结论:建议保持 torch.bool 类型,作者未回复但代码最终使用了 torch.bool(patch 中 mask 构造函数返回布尔张量)。 · 已解决

应限制 paged 前缀长度到滑动窗口大小 性能

Review 机器人建议将 paged 部分的 KV 长度限制到 window 内,避免长前缀场景下分配过大的 mask 矩阵。

结论:作者已实现该优化:通过 effective_start 截断到窗口大小。 · 已解决

风险与影响

回归风险:对 use_ragged=False 路径无变更,非 SWA 路径不受影响;多 item scoring (multi_item_params) 场景下 window_left 设为 -1,与原有行为一致。性能风险:新增的 mask 构建和零填充操作是 O(num_requests),可忽略。兼容性:仅修改内部逻辑接口,无配置或 API 变更。

影响范围:修复 FlashInfer 滑动窗口注意力在 EXTEND-with-prefix 路径下的正确性,影响使用前缀缓存的 SWA 模型用户。影响程度:高,修复了可能导致静默错误输出的严重 bug。

核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论