Prhub

#41736 [MM][CG] Support ViT CG for Qwen2-VL

原始 PR 作者 johncalesp 合并时间 2026-05-14 01:52 文件变更 4 提交数 9 评论 12 代码增减 +315 / -21

执行摘要

为 Qwen2-VL ViT 添加 CUDA Graph 支持,TTFT 降低约 51%

根据 PR 描述,此变更遵循 PR #35963 的先例,为 Qwen2-VL 启用 ViT CUDA Graphs,旨在减少视觉编码器的 GPU kernel launch 开销,降低首 token 延迟(TTFT)。提供的基准测试表明,无 CUDA Graph 时平均 TTFT 为 9910ms,启用后降至 4768ms。

此 PR 是实现多模态 CUDA Graph 支持的优秀范例,展示了如何为视觉模型适配 SupportsEncoderCudaGraph 协议。对于需要为其他模型启用类似优化的开发者,本 PR 的设计模式值得参考。建议精读 qwen2_vl.pyprepare_encoder_metadataforward 的修改。

讨论亮点

在代码审查中,gemini-code-assist[bot] 指出 prepare_encoder_metadatagrid_thw 为空列表时会触发 IndexError,建议添加 .reshape(-1, 3) 保护。作者 johncalesp 回应称 encoder_cudagraph.py 已经在数据并行(DP)场景下处理了空输入情况,因此无需修改。该讨论被关闭,未采取 bot 的建议。

实现拆解

  1. 实现 SupportsEncoderCudaGraph 接口:在 qwen2_vl.py 中导入 SupportsEncoderCudaGraphEncoderCudaGraphReplayBuffers,并让模型类 Qwen2VLForConditionalGeneration 继承该接口。
  2. 新增 prepare_encoder_metadata 方法:在 Qwen2VLEncoder 类中添加该方法,预计算 rotary position embedding 和 cu_seqlens,并通过 padding 保持形状静态,以便 CUDA Graph 重放。
  3. 修改 forward 方法:在 Qwen2VLEncoder.forward 中添加可选的 encoder_metadata 参数,当提供时直接使用预计算的元数据,避免动态计算。
  4. 实现必备协议方法:在 Qwen2VLForConditionalGeneration 中实现 get_input_modalityget_max_frames_per_videoget_encoder_cudagraph_budget_rangeget_encoder_cudagraph_num_items 等,为 CUDA Graph 调度提供配置。
  5. 更新测试:在 test_vit_cudagraph.py 中添加 Qwen2-VL-2B-Instruct 的测试配置。
  6. 更新示例和文档:在 vision_language_offline.py 中将 qwen2_vl 加入支持列表;在 cuda_graphs_multimodal.md 中添加 Qwen2-VL 的行。
文件 模块 状态 重要度
vllm/model_executor/models/qwen2_vl.py 模型实现 modified 9.05
tests/models/multimodal/generation/test_vit_cudagraph.py 测试 modified 4.94
examples/generate/multimodal/vision_language_offline.py 示例 modified 4.18
docs/design/cuda_graphs_multimodal.md 文档 modified 1.82

关键符号

prepare_encoder_metadata get_encoder_cudagraph_config get_input_modality get_max_frames_per_video get_encoder_cudagraph_budget_range _get_pixel_values_by_modality _get_grid_thw_by_modality get_encoder_cudagraph_num_items

关键源码片段

vllm/model_executor/models/qwen2_vl.py data-contract

主要实现文件,包含所有 CUDA Graph 接口方法和核心 forward 修改。

def prepare_encoder_metadata(
    self,
    grid_thw: list[list[int]],
    *,
    max_batch_size: int | None = None,
    max_frames_per_batch: int | None = None,
    max_seqlen_override: int | None = None,
    device: torch.device | None = None,
) -> dict[str, torch.Tensor]:
    # 默认使用当前设备
    if device is None:
        device = self.device
    # 预计算 rotary position embeddings
    rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw)
    # 从 grid_thw 计算 cu_seqlens
    grid_thw_np = np.array(grid_thw, dtype=np.int32)
    cu_seqlens = np.repeat(
        grid_thw_np[:, 1] * grid_thw_np[:, 2],
        grid_thw_np[:, 0],
    ).cumsum(dtype=np.int32)
    cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
    cu_seqlens = torch.from_numpy(cu_seqlens)
    # 填充 cu_seqlens 以保持形状稳定
    pad_to = (
        max_frames_per_batch if max_frames_per_batch is not None else max_batch_size
    )
    if pad_to is not None:
        num_seqs = len(cu_seqlens) - 1
        if num_seqs < pad_to:
            cu_seqlens = torch.cat((
                cu_seqlens,
                torch.full((pad_to - num_seqs,),
                           cu_seqlens[-1],
                           dtype=cu_seqlens.dtype,
                           device=cu_seqlens.device),
            ))
    # 计算或覆盖 max_seqlen
    if max_seqlen_override is None:
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
    else:
        max_seqlen = torch.tensor(max_seqlen_override, dtype=torch.int32)
    return {
        "rotary_pos_emb_cos": rotary_pos_emb_cos,
        "rotary_pos_emb_sin": rotary_pos_emb_sin,
        "cu_seqlens": cu_seqlens.to(device=device, non_blocking=True),
        "max_seqlen": max_seqlen,
    }def forward(
    self,
    x: torch.Tensor,
    grid_thw: torch.Tensor | list[list[int]],
    *,
    encoder_metadata: dict[str, torch.Tensor] | None = None,
) -> torch.Tensor:
    x = x.to(device=self.device, dtype=self.dtype)
    x = self.patch_embed(x)
    if encoder_metadata is None:
        encoder_metadata = self.prepare_encoder_metadata(
            grid_thw if isinstance(grid_thw, list) else grid_thw.tolist()
        )
    cu_seqlens = encoder_metadata["cu_seqlens"]
    max_seqlen = encoder_metadata["max_seqlen"]
    # 使用预计算元数据继续前向传播 ...

评论区精华

空 grid_thw 导致 IndexError 的风险 正确性

gemini-code-assist[bot] 指出 prepare_encoder_metadata 在 grid_thw 为空列表时会因数组维度不足抛出 IndexError,建议添加 .reshape(-1, 3)。

结论:作者回应 encoder_cudagraph.py 已经在 DP 场景下处理了空输入,无需修改。PR 合并时未采纳 bot 建议。 · 已解决

风险与影响

虽然作者声称 DP 场景已由上层处理,但 prepare_encoder_metadata 方法并未显式防御空 grid_thw,在非 DP 场景或未来重构后可能引发崩溃。另外,CUDA Graph 依赖静态形状,如果实际输入的视觉 token 数量超过预分配的 padding 大小,将导致 CUDA Graph 重放失败。测试中禁用了 prefix caching,启用 CG 与 prefix caching 的兼容性尚未验证。

仅影响 Qwen2-VL 模型用户,且要求使用 CUDA 后端并开启 cudagraph_mm_encoder 配置。性能提升明显,TTFT 降低约 52%,总吞吐量提升约 10%。不影响其他模型或默认行为。

核心路径变更 空输入未显式防御 与 prefix caching 兼容性未知

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论