Prhub

#40534 [Model] Gemma4: add bidirectional vision attention for sliding layers with window guard

原始 PR 作者 lucianommartins 合并时间 2026-04-24 16:27 文件变更 2 提交数 4 评论 8 代码增减 +73 / -1

执行摘要

Gemma4 双向视觉注意力支持及滑动窗口守卫

Gemma4 架构要求仅在 sliding_attention 层对视觉 token 应用双向注意力,这与 HF transformers 参考实现一致。直接使用现有 bidi 方案会导致全注意力层错误地获得双向注意力,并在图像 token 超过 sliding_window 时出现注意力爆炸,因此需要精确控制。见 issue #40106。

该 PR 实现清晰,注释详实,测试数据充分。建议开发者重点关注 _clear_mm_prefix_for_full_attn_layers 的设计模式:在 compiled graph 外部管理注意力元数据,避免侵入 torch.compile 区域。对多模态模型研发者具有参考价值。

讨论亮点
  • 性能优化:gemini-code-assist 建议避免在热路径中使用正则表达式和缺失 hasattr 检查。最终实现采用 frozenset 预计算索引,并在设置前检查属性存在性。
  • 架构决策:Isotr0py 认为在核心 model runner 中添加模型特定守卫(sliding window guard)比较 hacky,建议改由 triton kernel 正确支持 SWA+bidi。lucianommartins 同意但作为两步方案,优先合并守卫以快速解决问题。
  • 通用化:Isotr0py 指出守卫不仅限于 Gemma4,适用于所有结合 bidi 和滑动窗口的模型,建议移除模型类型检查。最终守卫只依赖 sliding_window 配置,不绑定模型。
  • 简化可能性:Isotr0py 提到可通过 PR#40701 简化实现,但未展开。

实现拆解

  1. 预计算全注意力层索引:在 Gemma4ForConditionalGeneration.__init__ 中解析 layer_types 配置,将非 sliding_attention 的层索引存入 _full_attn_layer_idxs(frozenset),避免每次 forward 时重复解析。
  2. 清除全注意力层的 mm_prefix_range:新增 _clear_mm_prefix_for_full_attn_layers 方法,在 forward 中(@support_torch_compile 边界外)调用,通过遍历注意力元数据字典,对层名提取索引,若属于全注意力层则置空 mm_prefix_rangemm_prefix_range_tensor,从而恢复因果掩码。
  3. 滑动窗口守卫:在 gpu_model_runner.py_build_attn_group_metadata 方法中,收集图像范围时增加检查:若范围长度超过 sliding_window 配置值,则跳过该范围,不加入 req_doc_ranges。这防止了超出窗口的图像 token 使用双向注意力导致的精度回归。
  4. 注册 MM_PREFIX_LM_MODELS:将 gemma4 加入该列表,以启用 mm_prefix_range 的自动填充机制。
文件 模块 状态 重要度
vllm/model_executor/models/gemma4_mm.py 多模态模型 modified 7.91
vllm/v1/worker/gpu_model_runner.py 模型执行 modified 6.25

关键符号

_clear_mm_prefix_for_full_attn_layers _process forward _build_attn_group_metadata

关键源码片段

vllm/model_executor/models/gemma4_mm.py core-logic

核心模型文件,实现 bidi 核心逻辑:预计算全注意力层索引、清除 mm_prefix_range、修改 forward 流程。

# gemma4_mm.py — 预计算全注意力层索引及清除元数据# 在 __init__ 中:
self._full_attn_layer_idxs: frozenset[int] = frozenset()
text_config = config.text_config
if getattr(text_config, 'use_bidirectional_attention', None) == 'vision':
    layer_types = getattr(text_config, 'layer_types', None)
    if layer_types:
        self._full_attn_layer_idxs = frozenset(
            i for i, lt in enumerate(layer_types) if lt != 'sliding_attention'
        )def _clear_mm_prefix_for_full_attn_layers(self) -> None:
    '''清除全注意力层的 mm_prefix_range 以强制因果掩码。    Gemma4 使用 `use_bidirectional_attention='vision'` 时只在
    sliding_attention 层启用双向注意力,全注意力层必须保持因果。
    该方法必须在 forward 调用之前执行(位于 @support_torch_compile
    边界外),因为编译器内部无法携带 Python 副作用。
    '''
    if not self._full_attn_layer_idxs:
        return
​
    from vllm.forward_context import get_forward_context
    attn_metadata = get_forward_context().attn_metadata
    if attn_metadata is None:
        return
​
    def _process(metadata_dict: dict) -> None:
        for layer_name, metadata in metadata_dict.items():
            # 从层名如 'model.layers.12.self_attn' 提取索引
            if '.layers.' not in layer_name:
                continue
            try:
                layer_idx = int(
                    layer_name.split('.layers.')[1].split('.')[0]
                )
            except (ValueError, IndexError):
                continue
            if layer_idx in self._full_attn_layer_idxs:
                if hasattr(metadata, 'mm_prefix_range'):
                    metadata.mm_prefix_range = None
                if hasattr(metadata, 'mm_prefix_range_tensor'):
                    metadata.mm_prefix_range_tensor = None
​
    if isinstance(attn_metadata, list):
        for ub_metadata in attn_metadata:
            _process(ub_metadata)
    elif isinstance(attn_metadata, dict):
        _process(attn_metadata)
vllm/v1/worker/gpu_model_runner.py core-logic

通用 model runner,添加滑动窗口守卫以跳过超过窗口大小的图像范围,防止 bidi 导致注意力爆炸。

# gpu_model_runner.py 内的 _build_attn_group_metadata 方法if self.is_mm_prefix_lm:
    req_doc_ranges = {}
​
    # 滑动窗口守卫:当图像 token 数超过 sliding_window 时,bidi
    # 会导致早期 token 关注整个图像(例如 6 → 1092 目标),
    # 降低空间精度。按范围过滤可对小图像 / 视频帧保持 bidi,
    # 同时跳过过大的图像范围。
    hf_text_config = self.model_config.hf_text_config
    _bidi_sw = getattr(hf_text_config, 'sliding_window', None)
​
    for req_id in self.input_batch.req_ids:
        image_doc_ranges = []
        req_state = self.requests[req_id]
        for mm_feature in req_state.mm_features:
            pos_info = mm_feature.mm_position
            img_doc_range = pos_info.extract_embeds_range()
            for r in img_doc_range:
                # 若范围长度超出滑动窗口则跳过该范围
                if _bidi_sw is not None and (r[1] - r[0] + 1) > _bidi_sw:
                    continue
                image_doc_ranges.append(r)
        req_idx = self.input_batch.req_id_to_index[req_id]
        req_doc_ranges[req_idx] = image_doc_ranges
​
    # 设置 mm_prefix_range 给所有注意力元数据
    self._set_mm_prefix_range_for_metadata(attn_metadata, req_doc_ranges)

评论区精华

性能优化:避免热路径正则表达式和添加 hasattr 检查 性能

gemini-code-assist 指出 `_clear_mm_prefix_for_full_attn_layers` 在每次前向传播中执行,使用 `re.search` 在循环中创建中间列表导致开销,且设置 `mm_prefix_range` 前需要 `hasattr` 检查以防注意力后端不支持。

结论:最终实现改用 frozenset 预计算索引,并在设置前使用 `hasattr` 检查,已采纳建议。 · 已解决

将模型特定逻辑从 core model runner 中移出 设计

Isotr0py 评论认为在 runner 中添加模型特定守卫是 hacky 的,应该更新 triton kernel 来正确处理 SWA+bidi。lucianommartins 回复同意但希望作为两步方案:先守卫再 kernel 改进。

结论:当前 PR 保留守卫作为临时方案,后续计划改进 kernel。未彻底解决,但 PR 被合并为中间步骤。 · unresolved

守卫应适用于所有结合 bidi 和 SWA 的模型 设计

Isotr0py 指出窗口大小守卫不仅限于 Gemma4,而是所有结合 bidi 和 SWA 的模型,建议移除模型类型检查。

结论:最终实现中未包含模型类型检查,守卫基于 sliding_window 配置通用适用。 · 已解决

可通过 PR#40701 简化守卫实现 other

Isotr0py 提到可以通过 PR#40701 简化守卫逻辑,但没有具体说明。

结论:未深入讨论,可能留给后续 PR。 · 待处理

风险与影响

  • 性能开销_clear_mm_prefix_for_full_attn_layers 在每次 forward 中遍历元数据字典,尽管使用 O(1) 的 frozenset 查找,但遍历所有层名可能带来微小开销。高并发场景需关注。
  • 后向兼容性:如果其他注意力后端不支持 mm_prefix_range 属性,会引发 AttributeError。代码已通过 hasattr 检查缓解。
  • 守卫保守性:滑动窗口守卫跳过超过窗口的图像范围,可能导致大图像失去双向注意力增益。但测试显示在此情况下精度无变化(无回归也无提升),因此作为安全折中。
  • 代码侵入性:模型特定逻辑(守卫)位于通用 model runner 中,增加了维护复杂度。未来应通过 kernel 改进移除。
  • 用户影响:Gemma4 模型用户无需额外配置即可获得 bidi 带来的准确率提升(MMMU-Pro 达 +1.1%),同时大图像场景因守卫保持精度稳定。
  • 系统影响:运行时增加少量开销(检查层索引、守卫过滤),但整体可忽略。
  • 团队影响:引入了两个快速修复点,后续需要跟进 triton kernel 改进以移除守卫,降低维护债务。
热路径性能开销 注意力后端兼容性 模型逻辑侵入核心模块

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论