Prhub

#24436 [Gemma 4] Adding MTP support

原始 PR 作者 kpham-sgl 合并时间 2026-05-08 05:08 文件变更 11 提交数 34 评论 26 代码增减 +1949 / -7

执行摘要

为 Gemma4 添加 FROZEN_KV_MTP 投机解码算法

Each Gemma 4 target ships with a small 'assistant' checkpoint trained for MTP. This PR introduces a new speculative algorithm — FROZEN_KV_MTP — that runs the assistant against the target's KV cache (the assistant has no KV of its own and a recurrent hidden state across draft steps, so it does not fit cleanly under EAGLE/EAGLE3 or NEXTN).

此 PR 对于 Gemma 4 用户至关重要,值得精读。设计上选择冻结 KV 方案而非传统 EAGLE 是合理的。关注点是 TP>1 支持尚未完全验证,数值掩码稳定性有待改进。建议后续跟进 TP 测试和掩码修复。

讨论亮点
  1. TP>1 支持:gemini-code-assist[bot] 指出助理模型的嵌入查找和 centroid 掩码 gather 在 TP>1 时会崩溃。作者 kpham-sgl 回应 self.target_embed_weight 不是 TP 分片的,所以问题不存在,但未完全解决其他 gather 问题。
  2. 数值掩码稳定性:review 指出 selected_logits.min() - 1.0 作为掩码值可能因 -inf 产生 nan,建议使用 -1e10。作者未回复,该问题可能未被采纳。
  3. 服务器参数副作用:review 指出 server_args.context_length 原地突变是坏实践。作者回应"Same pattern in Eagle",沿用现有模式。
  4. 信任远程代码:review 指出 trust_remote_code=True 被硬编码,应尊重用户 --trust-remote-code 标志。未看到修改。
  5. 测试注册:Qiaolin-Yu 询问是否将手动测试注册到 CI。作者回复将写一个较小的测试,后续在提交中添加了 CI 测试。

实现拆解

步骤 1:助理模型定义
python/sglang/srt/models/gemma4_mtp.py 中新增 Gemma4AssistantForCausalLM,继承自 Gemma4ForCausalLM。模型通过目标嵌入(target embed)和预投影/后投影实现循环隐藏状态,拥有自己的 lm_headbind_frozen_kv_context 方法将助理逻辑层映射到目标物理层。

步骤 2:投机算法与数据模型
spec_info.py 中新增 FROZEN_KV_MTP 枚举值和 is_frozen_kv_mtp 判别方法。frozen_kv_mtp_info.py 定义了 FrozenKVMTPContext(存储目标 KV 池和层映射)、FrozenKVMTPDraftInputFrozenKVMTPVerifyInput,后两者复用 EAGLE 的调度合约。

步骤 3:工作器实现
frozen_kv_mtp_worker.py 中的 FrozenKVMTPWorker 继承 TpModelWorker,负责草稿循环。它借用目标的 req_to_token_pool 和 KV 分配器(只读),强制禁用 CUDA 图嵌入和覆盖调度。_resolve_draft_backend_type 方法选择注意力后端。FrozenKVMTPWorkerV2 暂未实现,使用时必须 --disable-overlap-schedule

步骤 4:CUDA 图支持
frozen_kv_mtp_cuda_graph_runner.py 中的 FrozenKVMTPCudaGraphRunner 支持 topk=1 的简单模式和 topk>1 的树验证模式。_capture_graph 方法捕捉整个草稿步骤,_replay 实现低延迟回放。同时定义 FrozenKVMTPInputBuffers 数据类传递输入张量。

步骤 5:服务器配置与自动提升
server_args.py 新增 _resolve_speculative_algorithm_alias,当草稿模型架构为 Gemma4AssistantForCausalLM 时自动将 NEXTN/EAGLE 提升为 FROZEN_KV_MTP,并拒绝 EAGLE3。当算法激活时,强制禁用 overlap_scheduler 和 mixed_chunked_prefill,默认 max-running-requests 为 48。

步骤 6:现有模型适配
gemma4_causal.pygemma4_mm.py 暴露 get_embed_and_head 方法,供助理在加载时重新绑定到目标输入嵌入。hf_transformers/config.py 识别 gemma4_assistant 类型。

文件 模块 状态 重要度
python/sglang/srt/models/gemma4_mtp.py 模型层 added 9.17
python/sglang/srt/speculative/frozen_kv_mtp_worker.py 投机解码 added 8.99
python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py 投机解码 added 8.89
python/sglang/srt/speculative/frozen_kv_mtp_utils.py 投机解码 added 8.89
python/sglang/srt/speculative/frozen_kv_mtp_info.py 投机解码 added 8.52
python/sglang/srt/speculative/frozen_kv_mtp_worker_v2.py 投机解码 added 6.94
python/sglang/srt/server_args.py 配置 modified 7.21
python/sglang/srt/speculative/spec_info.py 投机解码 modified 6.71
python/sglang/srt/models/gemma4_mm.py 模型层 modified 5.74
python/sglang/srt/models/gemma4_causal.py 模型层 modified 5.43
python/sglang/srt/utils/hf_transformers/config.py 配置 modified 4.49

关键符号

Gemma4AssistantForCausalLM.__init__ Gemma4AssistantForCausalLM.get_embed_and_head Gemma4AssistantForCausalLM.bind_frozen_kv_context FrozenKVMTPWorker.__init__ FrozenKVMTPWorker.draft_model_runner FrozenKVMTPCudaGraphRunner.__init__ FrozenKVMTPCudaGraphRunner._capture_graph FrozenKVMTPCudaGraphRunner._replay _resolve_speculative_algorithm_alias is_frozen_kv_mtp frozen_kv_target_view expand_for_topk_draft set_frozen_kv_positions

关键源码片段

python/sglang/srt/models/gemma4_mtp.py data-contract

核心助理模型定义,包括 Gemma4AssistantForCausalLM 及其冻结 KV 绑定逻辑

class Gemma4AssistantForCausalLM(Gemma4ForCausalLM):
    """Gemma 4 MTP 助理模型:使用目标嵌入 + 循环隐藏状态,拥有自己的 lm_head。"""
​
    base_model_prefix = "model"
​
    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None:
        # 深拷贝文本配置并禁用 KV 共享(助理不管理 KV)
        text_config = copy.deepcopy(_get_text_config(config))
        text_config.num_kv_shared_layers = 0
        PreTrainedModel.__init__(self, config=text_config)
        self.assistant_config = config
        self.config = text_config
        self.quant_config = quant_config
​
        self.vocab_size = text_config.vocab_size
        self.hidden_size = text_config.hidden_size
        # backbone_hidden_size 来自助理配置,是目标嵌入的维度
        self.backbone_hidden_size = config.backbone_hidden_size
        # 目标嵌入缩放因子,用于将目标嵌入投影到助理隐藏空间
        self.target_embed_scale = self.backbone_hidden_size ** 0.5
        self.use_ordered_embeddings = getattr(config, "use_ordered_embeddings", False)
        self.centroid_intermediate_top_k = int(getattr(config, "centroid_intermediate_top_k", 32))
​
        # 目标嵌入权重将在加载时由 bind_frozen_kv_context 绑定
        self.target_embed_weight = None
​
        # 预投影:拼接目标嵌入和循环隐藏状态
        self.pre_projection = ReplicatedLinear(2 * self.backbone_hidden_size, self.hidden_size, bias=False, quant_config=None)
        # 助理骨干网络,复用 Gemma4TextModel
        self.model = Gemma4TextModel(config=text_config, quant_config=quant_config, prefix=add_prefix("model", prefix))
        # 后投影:将骨干输出映射回 backbone_hidden_size
        self.post_projection = ReplicatedLinear(self.hidden_size, self.backbone_hidden_size, bias=False, quant_config=None)
​
        # 语言模型头:如果词汇嵌入共享则绑定到 embed_tokens
        if text_config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        self.logits_processor = LogitsProcessor(text_config, skip_all_gather=True)
​
        # 如果启用有序嵌入(centroid),则配置 centroid 词汇头 ...(后续代码省略)
python/sglang/srt/speculative/frozen_kv_mtp_worker.py dependency-wiring

FrozenKVMTPWorker 实现草稿生成循环,协调目标 KV 访问和树验证

class FrozenKVMTPWorker(TpModelWorker):
    """Frozen-KV MTP 工作器。助理只读目标 KV,重复使用 EAGLE 的验证合约。"""
​
    def __init__(self, server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, attn_cp_rank, moe_dp_rank, nccl_port, target_worker):
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.target_worker = target_worker
​
        # 确保算法类型正确
        assert self.speculative_algorithm.is_frozen_kv_mtp()
​
        # 助理上下文长度必须与目标一致(副作用警告)
        server_args.context_length = target_worker.model_runner.model_config.context_len
​
        # 禁用 CUDA 图(我们自己管理)
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
​
        # 复用目标的内存池(只读)
        self.req_to_token_pool, self.token_to_kv_pool_allocator = target_worker.get_memory_pool()
​
        # 配置草稿注意力后端 ...(后续代码省略)
python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py dependency-wiring

CUDA 图运行器实现低延迟草稿步骤捕获与回放

@dataclass
class FrozenKVMTPInputBuffers(ForwardInputBuffers):
    req_pool_indices: torch.Tensor
    positions: torch.Tensor
    seq_lens: torch.Tensor
    hidden_states: torch.Tensor # 循环隐藏状态
    topk_p: torch.Tensor # 树验证概率
    topk_index: torch.Tensor # 树验证索引
​
​
class FrozenKVMTPCudaGraphRunner:
    """CUDA 图运行器,用于 Frozen-KV MTP 的循环草稿步骤。"""
​
    def __init__(self, frozen_kv_mtp_worker):
        self.model_runner = frozen_kv_mtp_worker.draft_model_runner
        self.speculative_num_steps = self.model_runner.server_args.speculative_num_steps
        self.topk = self.model_runner.server_args.speculative_eagle_topk
        self.num_tokens_per_bs = self.topk
        # 获取待捕获的 batch sizes
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner, self.num_tokens_per_bs)
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
        # 初始化输入缓冲区并分配 CUDA 张量 ...(后续代码省略)

评论区精华

TP>1 时嵌入查找错误 正确性

gemini-code-assist[bot] 指出 `self.target_embed_weight` 在 TP>1 时是分片的,使用全局 input_ids 嵌入会产生错误结果,需要 all_reduce 或使用并行嵌入层。

结论:作者回复 target_embed_weight 不是 TP 分片的,忽略该问题。但 centroid 掩码 gather 和词汇重排序在 TP>1 时可能仍会崩溃。 · unresolved

centroid 掩码 gather 在 TP>1 时崩溃 正确性

gemini-code-assist[bot] 指出 lm_head 若绑定到 embed_tokens,其权重是 VocabParallelEmbedding 的分片,view 操作会因分片大小不匹配而失败。

结论:作者回复 'ditto',认为是相同原因忽略。 · unresolved

词汇重排序在 TP>1 时索引越界 正确性

gemini-code-assist[bot] 指出 ordering 包含全局索引,但 lm_head_w 是词汇分片,直接索引会越界。

结论:作者回复 'ditto'。 · unresolved

trust_remote_code 硬编码 安全

gemini-code-assist[bot] 指出 `trust_remote_code=True` 被硬编码在 `_resolve_speculative_algorithm_alias` 中,应使用用户提供的 `--trust-remote-code` 标志。

结论:未看到修改,可能维持硬编码。 · unresolved

掩码值数值不稳定 正确性

gemini-code-assist[bot] 建议将 `selected_logits.min() - 1.0` 替换为 `-1e10` 或 `-torch.inf`,避免 -inf 导致 nan。

结论:作者未回复,可能未采纳。 · unresolved

server_args.context_length 原地突变 设计

gemini-code-assist[bot] 指出突变 `server_args.context_length` 是有副作用的,建议改用局部变量。

结论:作者回复 'Same pattern in Eagle',沿用现有模式。 · addressed

CUDA 图运行器使用通用 Exception style

gemini-code-assist[bot] 建议改用更具体的异常类型如 RuntimeError。

结论:未看到修改。 · unresolved

测试注册到 CI 测试

Qiaolin-Yu 询问是否将手动 GSM8K 测试注册到 CI。作者回复将写一个较小的测试,后续提交中添加了 stage b CI 测试(后因 transformers 版本问题暂时移除)。

结论:已添加 CI 测试但后续因依赖版本暂移除,最终状态待定。 · 已解决

风险与影响

  1. TP>1 兼容性:尽管作者声称某些张量不是 TP 分片的,但 centroid 掩码 gather 和词汇重排序逻辑在 TP>1 时仍可能崩溃。需要针对 TP>1 进行专门测试。
  2. 数值稳定性:掩码值可能因 -inf 导致 nan,影响采样质量。
  3. 服务器参数副作用server_args.context_length 被突变,可能引起日志混乱和调试困难。
  4. 通用异常:CUDA 图运行器中抛出的泛型 Exception 可能导致错误信息模糊。
  5. 覆盖调度器强制禁用:FROZEN_KV_MTP 强制禁用覆盖调度器,这可能影响与其他工作负载的兼容性。
  1. 用户影响:Gemma 4 用户可通过简单参数启用 MTP 加速,获得接近无损的准确度。但需要手动升级 Transformers 和设置正确参数。
  2. 系统影响:新投机算法增加了调度器分支,强制禁用某些优化(overlap schedule、mixed chunked prefill),可能影响整体性能。
  3. 团队影响:需要维护新的算法代码路径,特别是 TP 扩展和后续性能优化。
TP>1 兼容性未验证 数值掩码稳定性风险 服务器参数副作用 通用异常抛出 覆盖调度器强制禁用

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论