Prhub

#43241 [Model Runner V2][Spec Decode] Add Gemma4 MTP support

原始 PR 作者 TheEpicDolphin 合并时间 2026-06-04 08:51 文件变更 14 提交数 1 评论 26 代码增减 +1243 / -942

执行摘要

为 MRV2 添加 Gemma4 MTP 支持并重构 speculator 架构

PR body 明确指出 Gemma4 MTP 在 MRV1 中已支持(见 #41745),但 MRV2 中缺失。关键功能需求包括:draft 步使用常位置(constant positions)、draft 层复用目标模型同一 attention group 的最终 target 层的 KV cache、draft forward 返回元组 (draft_hidden_states, backbone_hidden_states)。原有 EagleSpeculator 全部代码将变得不可维护,因此决定重构 speculator 类体系以支持这些特性。

值得精读。本 PR 展示了一次成功的大规模重构实践:通过属性化钩子和模板方法模式将通用逻辑与模型特定逻辑分离,代码行数净减约 700 行。设计决策(如属性 vs 继承、钩子的粒度)值得在类似扩展点设计中借鉴。关注点:测试覆盖需后续补齐,建议阅读 merge 后的 #44253(修复了 capture 阶段 attention state 问题)。

讨论亮点

Review 中有以下主要讨论点:

  • init_attn_backend 返回值解包错误 (gemini-code-assist[bot], depthfirst-app[bot]):init_attn_backend 返回4个值但代码只解包3个,会导致运行时 ValueError。作者确认已修复。
  • 设计未遵循之前讨论的方案 (benchislett):认为此设计更接近 MRV1 的参数化 propose 方法而非分离式设计。但后续 approve 表明达成共识。
  • 动态导入的动机 (benchislett):询问为何使用动态 import,作者解释延续现有模式且避免加载未使用的 speculator 模块。
  • 测试覆盖不足 (benchislett):询问测试覆盖情况,作者承诺后续 PR 添加参数化测试(MRV1+MRV2)。
  • CUDA graph 文件移动 (benchislett):建议将 cudagraph 工具移入 autoregressive 目录,作者执行。
  • MTPSpeculator 继承关系 (benchislett):建议考虑继承 EagleSpeculator,作者选择保持显式独立以避免混淆。

实现拆解

  1. 新建基础抽象层 (spec_decode/speculator.py):定义 BaseSpeculator(包含 init_cudagraph_managercapturepropose 抽象方法)和 DraftModelSpeculator(继承 BaseSpeculator,管理共享状态如 DP 配置、attention 后端初始化、模型加载等)。后者声明 load_draft_model 为扩展点。

  2. 提取通用自回归推测循环 (autoregressive/speculator.py):AutoRegressiveSpeculator 继承 DraftModelSpeculator,封装 CUDA graph 捕获、draft propose 流程(prefill + 多步 decode)、Triton 输入预处理内核。通过三个属性钩子 (advance_draft_positions, model_returns_tuple, advance_draft_positions) 和两个方法钩子 (load_draft_model, sample_draft) 实现模型特定行为的可扩展性。

  3. 精简 EagleSpeculator (eagle/speculator.py):原 900+ 行缩减为 18 行,仅继承 AutoRegressiveSpeculator 并覆盖 load_draft_model 调用 load_eagle_model

  4. 新增 Gemma4Speculator (gemma4/speculator.py):覆盖 advance_draft_positions 返回 False(常位置)、model_returns_tuple 返回 True(元组输出)、load_draft_model(创建 draft vllm config 保留目标模型 TRITON_ATTN 后端、建立 KV 共享、共享嵌入层)。

  5. 新增 MTPSpeculator (mtp/speculator.py):覆盖 model_returns_tuple 返回 False(单 tensor)、load_draft_model 调用 load_eagle_model,适配 DeepSeek 等标准 MTP 模型。

  6. CUDA graph 管理器重命名与移动PrefillEagleCudaGraphManagerPrefillSpeculatorCudaGraphManagerDecodeEagleCudaGraphManagerDecodeSpeculatorCudaGraphManager,并移至 autoregressive/cudagraph_utils.py

  7. 动态导入调整 (spec_decode/__init__.py):根据配置调用 use_gemma4_mtp()/use_mtp() 动态导入对应 speculator 类。

  8. 辅助清理:移除 vllm/v1/attention/backends/utils.py 中的 get_num_attention_heads_from_layers(已不再使用),调整 flashinfer.pytriton_attn.py 中对 get_num_attention_heads_from_layers 的引用。

文件 模块 状态 重要度
vllm/v1/worker/gpu/spec_decode/autoregressive/speculator.py 自回归推测 added 8.99
vllm/v1/worker/gpu/spec_decode/speculator.py 推测器接口 added 8.9
vllm/v1/worker/gpu/spec_decode/gemma4/speculator.py Gemma4 推测 added 8.68
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py Eagle 推测 modified 8.69
vllm/v1/worker/gpu/spec_decode/mtp/speculator.py MTP 推测 added 7.37
vllm/v1/worker/gpu/spec_decode/autoregressive/cudagraph_utils.py CUDA Graph renamed 7.18
vllm/v1/worker/gpu/spec_decode/__init__.py 入口点 modified 6.29
vllm/v1/worker/gpu/model_runner.py 模型运行器 modified 5.31
vllm/v1/attention/backends/utils.py 注意力后端 modified 6.04
vllm/v1/attention/backends/triton_attn.py 注意力后端 modified 5.12
vllm/v1/attention/backends/flashinfer.py 注意力后端 modified 4.62

关键符号

BaseSpeculator.init_cudagraph_manager BaseSpeculator.capture BaseSpeculator.propose DraftModelSpeculator.__init__ DraftModelSpeculator.load_draft_model DraftModelSpeculator.load_model AutoRegressiveSpeculator.__init__ AutoRegressiveSpeculator.advance_draft_positions AutoRegressiveSpeculator.model_returns_tuple AutoRegressiveSpeculator.init_cudagraph_manager AutoRegressiveSpeculator.capture AutoRegressiveSpeculator.propose AutoRegressiveSpeculator.sample_draft Gemma4Speculator.advance_draft_positions Gemma4Speculator.model_returns_tuple Gemma4Speculator.load_draft_model Gemma4Speculator._create_draft_vllm_config Gemma4Speculator._setup_gemma4_kv_sharing Gemma4Speculator._share_embeddings MTPSpeculator.model_returns_tuple MTPSpeculator.load_draft_model EagleSpeculator.load_draft_model

关键源码片段

vllm/v1/worker/gpu/spec_decode/autoregressive/speculator.py core-logic

核心新增文件,封装通用自回归推测循环,定义了可扩展的属性和方法钩子

class AutoRegressiveSpeculator(DraftModelSpeculator):
    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        super().__init__(vllm_config, device)
        self.hidden_states = torch.zeros(
            self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device)
        self.current_draft_step = torch.tensor(0, dtype=torch.int64, device=device)
        self.last_token_indices = torch.zeros(
            self.max_num_reqs, dtype=torch.int64, device=device)
        # 支持多模态输入时额外分配 embed 缓存
        self.supports_mm_inputs = MULTIMODAL_REGISTRY.supports_multimodal_inputs(
            self.draft_model_config)
        if self.supports_mm_inputs:
            self.inputs_embeds = torch.zeros(
                self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device)
        self.prefill_cudagraph_manager: PrefillSpeculatorCudaGraphManager | None = None
        self.decode_cudagraph_manager: DecodeSpeculatorCudaGraphManager | None = None
​
    # 属性钩子:每个 draft 步骤是否推进位置(True 为 Eagle/MTP 产生新 KV,False 为 Gemma4 Q-only)
    @property
    def advance_draft_positions(self) -> bool:
        return True
​
    # 属性钩子:draft forward() 是否返回 (last_hidden_states, hidden_states) 元组
    @property
    def model_returns_tuple(self) -> bool:
        return True
​
    def init_cudagraph_manager(self, cudagraph_mode: CUDAGraphMode) -> None:
        # 初始化 prefill CUDA Graph manager(覆盖 draft 步 0 到 num_speculative_steps)
        self.prefill_cudagraph_manager = PrefillSpeculatorCudaGraphManager(
            self.vllm_config, self.device, cudagraph_mode,
            self.num_speculative_steps + 1)
        # PIECEWISE 模式在 decode 阶段不支持,降级为 FULL_DECODE_ONLY 或 NONE
        if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
            cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
        else:
            cudagraph_mode = CUDAGraphMode.NONE
        self.decode_cudagraph_manager = DecodeSpeculatorCudaGraphManager(
            self.vllm_config, self.device, cudagraph_mode, decode_query_len=1)
​
    def capture(self, attn_states: dict[BatchExecutionDescriptor, AttentionStatePair]):
        logger.info("Capturing model for speculator...")
        self.last_token_indices.zero_()
        # 捕获 prefill 阶段(model forward + compute_logits + sample)
        assert self.prefill_cudagraph_manager is not None
        if self.prefill_cudagraph_manager.use_breakable_cg:
            self.prefill_cudagraph_manager.init_breakable_cg_runner(self.model)
        self.prefill_cudagraph_manager.capture(
            self._prefill, attn_states,
            progress_bar_desc="Capturing prefill CUDA graphs")
        if self.num_speculative_steps == 1:
            return
        # 捕获 decode 阶段(model forward + sample + update_draft_inputs)
        self.decode_cudagraph_manager.capture(
            self._decode, attn_states,
            progress_bar_desc="Capturing decode CUDA graphs")
vllm/v1/worker/gpu/spec_decode/speculator.py core-logic

定义基础抽象和共享状态,是重构后的根接口

from abc import ABC, abstractmethodclass BaseSpeculator(ABC):
    @abstractmethod
    def init_cudagraph_manager(self, cudagraph_mode: CUDAGraphMode) -> None: pass
​
    @abstractmethod
    def capture(self, attn_states: dict[BatchExecutionDescriptor, AttentionStatePair]) -> None: pass
​
    @abstractmethod
    def propose(self, input_batch: InputBatch, attn_metadata, slot_mappings,
                last_hidden_states, aux_hidden_states, num_sampled, num_rejected,
                last_sampled, next_prefill_tokens, temperature, seeds,
                num_tokens_across_dp=None, dummy_run=False,
                skip_attn_for_dummy_run=False, mm_inputs=None,
                is_profile=False) -> torch.Tensor: passclass DraftModelSpeculator(BaseSpeculator):
    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        # 从配置中提取推测相关参数
        self.vllm_config = vllm_config
        self.device = device
        speculative_config = vllm_config.speculative_config
        assert speculative_config is not None
        self.num_speculative_steps = speculative_config.num_speculative_tokens
        self.draft_model_config = speculative_config.draft_model_config
        self.hidden_size = self.draft_model_config.get_hidden_size()
        # 若存在 HC 多重隐变量(如 DeepSeek V4),扩大 hidden_size
        hc_mult = getattr(self.draft_model_config.hf_config, "hc_mult", 1)
        self.hidden_size = self.hidden_size * hc_mult
        self.vocab_size = self.draft_model_config.get_vocab_size()
        self.dtype = vllm_config.model_config.dtype
        # 分配共享缓冲区
        self.input_buffers = InputBuffers(max_num_reqs=self.max_num_reqs,
                                          max_num_tokens=self.max_num_tokens, device=device)
        self.idx_mapping = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=device)
        self.temperature = torch.zeros(self.max_num_reqs, dtype=torch.float32, device=device)
        self.seeds = torch.zeros(self.max_num_reqs, dtype=torch.int64, device=device)
        self.draft_tokens = torch.zeros(self.max_num_reqs, self.num_speculative_steps,
                                        dtype=torch.int64, device=device)
        self.arange = torch.arange(self.max_num_reqs + 1, dtype=torch.int32, device="cpu")
        pass
​
    def load_model(self, target_model: nn.Module, target_attn_layer_names: set[str]):
        # 调用子类的 load_draft_model 并缓存模型
        self.model = self.load_draft_model(target_model, target_attn_layer_names)
        self._validate_target_attention_layer_names(target_attn_layer_names)
        return self.model
​
    def load_draft_model(self, target_model: nn.Module, target_attn_layer_names: set[str]) -> nn.Module:
        raise NotImplementedError
vllm/v1/worker/gpu/spec_decode/gemma4/speculator.py core-logic

新增 Gemma4 专用推测器,实现常位置、KV 共享和特殊配置传播

class Gemma4Speculator(AutoRegressiveSpeculator):
    # Gemma4 MTP 是 Q-only,不写入新 KV,固定序列长度
    @property
    def advance_draft_positions(self) -> bool:
        return False
​
    # draft forward 返回 (draft_hidden_states, backbone_hidden_states)
    @property
    def model_returns_tuple(self) -> bool:
        return True
​
    def load_draft_model(self, target_model: nn.Module,
                         target_attn_layer_names: set[str]) -> nn.Module:
        # 1. 创建保留目标模型 attention backend 的 vllm config
        draft_vllm_config = self._create_draft_vllm_config()
        with set_model_tag("eagle_head"):
            draft_model = get_model(vllm_config=draft_vllm_config,
                                    model_config=self.speculative_config.draft_model_config,
                                    load_config=self.speculative_config.draft_load_config)
        # 2. 建立 draft 与 target 之间的 KV 共享
        self._setup_gemma4_kv_sharing(draft_model, target_attn_layer_names)
        # 3. 让 draft 复用 target 的 embed_tokens
        self._share_embeddings(draft_model, target_model)
        return draft_model
​
    def _create_draft_vllm_config(self) -> VllmConfig:
        # 生成 draft 的 vllm config,强制继承 target 的 attention backend
        # Gemma4 强制 TRITON_ATTN(因为不兼容 head_dim=256/512 冲突),
        # 但 DraftModelSpeculator 基类会重置 backend 为 None,必须重写。
        draft_vllm_config = replace(self.vllm_config,
                                    model_config=self.speculative_config.draft_model_config)
        target_backend = self.vllm_config.attention_config.backend
        if target_backend is not None:
            draft_vllm_config = replace(draft_vllm_config,
                                        attention_config=replace(
                                            draft_vllm_config.attention_config,
                                            backend=target_backend))
        return draft_vllm_config
​
    def _setup_gemma4_kv_sharing(self, model: nn.Module,
                                  target_attn_layer_names: set[str]):
        # 对每个 draft decoder 层,找到对应的 target attention 层并分享 KV 缓存
        draft_config = self.speculative_config.draft_model_config.hf_config
        draft_text_config = draft_config.get_text_config()
        target_text_config = self.vllm_config.model_config.hf_config
        target_layer_types = getattr(target_text_config, "layer_types", [])
        if not (hasattr(model, "model") and hasattr(model.model, "layers")):
            return
        target_num_kv_shared = getattr(target_text_config, "num_kv_shared_layers", 0)
        num_non_shared = len(target_layer_types) - target_num_kv_shared
        type_to_target_indices: dict[str, list[int]] = defaultdict(list)
        for idx, lt in enumerate(target_layer_types[:num_non_shared]):
            type_to_target_indices[lt].append(idx)
        # ... 遍历 draft 层并映射到对应 target 层,实现 KV 指针替换
        # ( 完整逻辑见源码 )

评论区精华

init_attn_backend 返回值解包数量错误 正确性

gemini-code-assist[bot] 和 depthfirst-app[bot] 指出 init_attn_backend 返回 4 个值但代码只解包 3 个变量,会导致运行时 ValueError。

结论:作者已修复解包数量。 · 已解决

设计未遵循之前讨论的方案 设计

benchislett 指出此设计更接近 MRV1 的参数化 propose 方法,而非分离式方案,但后续 approve。

结论:达成共识,保留当前设计。 · 已解决

动态导入动机 question

benchislett 询问为何使用动态 import,作者解释延续现有模式并避免加载未使用的 speculator 模块。

结论:不做修改。 · 已解决

测试覆盖不足 测试

benchislett 询问测试覆盖情况,担心未来被破坏。作者承诺后续 PR 添加参数化测试(MRV1+MRV2)。

结论:待后续补充。 · 待处理

CUDA graph 管理类重命名与移动 设计

benchislett 建议将 cudagraph 工具移入 autoregressive 目录,作者同意并执行。

结论:已完成重命名和移动。 · 已解决

MTPSpeculator 继承关系 设计

benchislett 建议考虑继承 EagleSpeculator,作者认为保持独立更清晰。

结论:保持独立。 · 已解决

DraftModelSpeculator 命名 style

benchislett 认为该名字可以更好。

结论:暂未改动。 · unresolved

风险与影响

  1. 回归风险:EagleSpeculator 从 900+ 行缩减至 18 行,大量逻辑移至父类,可能引入回归。虽然测试尚未添加,但合并后后续 commit 在修复相关 bug(如 #44253)时可能会暴露问题。
  2. KV 共享兼容性:Gemma4Speculator 的 KV 共享机制假设 target 模型具有特定 layer type 配置(num_kv_shared_layers, layer_types),若目标模型没有这些属性则可能跳过共享。需确保在非 Gemma4 场景下安全。
  3. 配置一致性_create_draft_vllm_config 强制传递目标模型 attention backend 给 draft 模型,若目标未设置 backend(None)则 draft 可能 fallback 到不支持 KV share 的实现。
  4. init_attn_backend 返回值虽已修复,但类似解包错误可能在其他路径仍然存在。
  5. 缺少测试覆盖:当前无测试文件变更,无法保证变更的稳定性,后续需补全。

用户影响:Gemma4 模型用户在启用 MRV2 后可获得 MTP 推测解码加速(benchmark 显示与 MRV1 性能持平)。系统影响:speculator 架构清晰分层,降低添加新推测模型的门槛。团队影响:Eagle 和 MTP 相关代码更简洁,但需要对新抽象有理解成本。无 Breaking changes(所有现有模型接口保持兼容)。

缺少测试覆盖 核心架构重构 KV 共享兼容性 init_attn_backend 依赖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论