Prhub

#26468 [Model] Add Qwen3-MoE MTP

原始 PR 作者 ByronHsu 合并时间 2026-05-30 05:31 文件变更 5 提交数 3 评论 7 代码增减 +163 / -4

执行摘要

为 Qwen3-MoE 添加 MTP 推测解码草稿模型

PR 描述提到 "添加 Qwen3-MoE MTP 草稿模型条目用于独立推测解码",目的是让 Qwen3-MoE 模型也能利用 MTP 风格的推测解码加速推理。这是继 GPT-OSS MTP 支持后的又一次模型扩展,参考 #26673 等历史 PR。

该 PR 实现了必要的功能扩展,设计上复用父类 load_weights 的思路值得学习。但 review 中提出的两个问题(权重重命名逻辑和 super.init 跳过)未修复即合并,存在一定风险。建议读者关注未来是否有后续修复 PR,并在自己的部署中注意检查权重加载正确性。

讨论亮点
  1. 权重命名替换风险(correctness):gemini-code-assist 指出当前 mtpmodel 的替换逻辑可能导致 mtp.model.layers.0... 变成 model.model.layers...,建议改用 name.startswith('mtp.') 直接去除前缀。该问题未在合并前解决。

  2. 绕过 super().init 的隐患(design):gemini-code-assist 指出 Qwen3MoeForCausalLMMTP.__init__ 直接调用 nn.Module.__init__ 而非 super().__init__,会跳过基类 attn_cp_size 等属性的初始化,且原地修改 config.num_hidden_layers 可能影响共享配置对象。建议复制 config 并正确调用 super。该问题也未在合并前解决。

实现拆解

  1. 新增草稿模型类:在 python/sglang/srt/models/qwen3_moe_mtp.py 中创建 Qwen3MoeForCausalLMMTP,继承自 Qwen3MoeForCausalLM,重写 __init__forwardload_weights 等方法,实现单层 MTP 前向逻辑。

  2. 修改权重加载:在 python/sglang/srt/models/qwen3_moe.pyload_weights 方法中添加 is_mtp 参数,当为 True 时只加载包含 mtp. 前缀的权重,并将 mtp 前缀重映射为模型内部参数名(如 mtp.fc.weightfc.weightmtp.model.*model.*)。

  3. 配置注册:在 python/sglang/srt/configs/model_config.py_config_draft_model 方法中添加分支,将架构 Qwen3MoeForCausalLM 映射为 Qwen3MoeForCausalLMMTP 并设置 num_nextn_predict_layers = 1

  4. 缓存连续性保障:在 python/sglang/srt/speculative/eagle_worker_v2.pypython/sglang/srt/speculative/eagle_worker.pydraft_forward 方法中,当检测到草稿模型架构为 Qwen3MoeForCausalLMMTP 时,调用 out_cache_loc.contiguous() 确保缓存位置满足融合 RoPE + KV-store 路径的要求。

文件 模块 状态 重要度
python/sglang/srt/models/qwen3_moe_mtp.py 模型层 added 9.07
python/sglang/srt/models/qwen3_moe.py 模型层 modified 7.01
python/sglang/srt/configs/model_config.py 配置 modified 5.44
python/sglang/srt/speculative/eagle_worker_v2.py 推测解码器 modified 5.4
python/sglang/srt/speculative/eagle_worker.py 推测解码器 modified 5.19

关键符号

Qwen3MoeForCausalLMMTP.__init__ Qwen3MoeForCausalLMMTP.forward Qwen3MoeForCausalLMMTP.load_weights Qwen3MoeForCausalLM.load_weights (modified) EagleWorker.draft_forward (contiguity patch) EagleWorkerV2.draft_forward (contiguity patch)

关键源码片段

python/sglang/srt/models/qwen3_moe_mtp.py data-contract

核心新增文件:定义了 Qwen3MoE 的 MTP 草稿模型类,包含初始化、前向和权重加载逻辑。

class Qwen3MoeForCausalLMMTP(Qwen3MoeForCausalLM):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        # 注意:此处直接调用 nn.Module.__init__,而非 super().__init__,
        # 跳过了基类中 attn_cp_size 等属性的初始化,存在风险。
        nn.Module.__init__(self)
        self.config = config
        # 强制将 num_hidden_layers 设为 1,MTP 模型只需要单层
        config.num_hidden_layers = 1
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.pp_group = get_pp_group()
​
        # MTP 特有的融合层:将 embedding 和 hidden_states 拼接后映射到 hidden_size
        self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
        self.pre_fc_norm_embedding = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_fc_norm_hidden = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
​
        # 复用 Qwen3MoeModel,但只包含 1 层
        self.model = Qwen3MoeModel(config, quant_config, prefix=add_prefix("model", prefix))
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
        )
        self.logits_processor = LogitsProcessor(config)
​
        # 以下 mapping 用于复用父类的 load_weights,需与基类保持一致
        self.stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        self.expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )
        self.capture_aux_hidden_states = False
​
    @torch.no_grad()
    def forward(self, input_ids, positions, forward_batch, input_embeds=None, **kwargs):
        # 无 input_embeds 时从 token id 嵌入
        if input_embeds is None:
            input_embeds = self.model.embed_tokens(input_ids)
​
        hidden_states = forward_batch.spec_info.hidden_states # 来自目标模型
​
        if not forward_batch.forward_mode.is_idle():
            # 对两个来源分别做 LayerNorm
            input_embeds = self.pre_fc_norm_embedding(input_embeds)
            hidden_states = self.pre_fc_norm_hidden(hidden_states)
        # 拼接并线性变换
        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
​
        with get_global_expert_distribution_recorder().disable_this_region():
            hidden_states = self.model(input_ids, positions, forward_batch, hidden_states)
​
        return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch)
​
    def load_weights(self, weights):
        # 调用父类 load_weights,设置 is_mtp=True 使其只处理 mtp 前缀的权重
        super().load_weights(weights, is_mtp=True)
python/sglang/srt/models/qwen3_moe.py data-contract

修改了权重加载方法,添加 is_mtp 参数支持 MTP 权重分离。

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False):
    # ... 之前的栈式参数映射和专家映射设置不变 ...
    for name, loaded_weight in weights:
        if is_mtp:
            # MTP 模式:只加载包含 "mtp" 的权重
            if "mtp" not in name:
                continue
            # 将特定前缀的权重去掉 "mtp." 前缀(如 mtp.fc.weight → fc.weight)
            if name in [
                "mtp.fc.weight",
                "mtp.pre_fc_norm_embedding.weight",
                "mtp.pre_fc_norm_hidden.weight",
            ]:
                name = name.replace("mtp.", "")
            else:
                # 其他以 mtp 开头的权重要将 "mtp" 替换为 "model"
                # 注意:此处简单替换可能导致 mtp.model.layers 变成 model.model.layers,
                # 正确的做法应是 strip 前缀,但当前 PR 未修复此问题。
                name = name.replace("mtp", "model")
        elif "mtp" in name:
            # 非 MTP 模式:跳过所有 mtp 前缀的权重
            continue
        # 后续的加载逻辑保持不变 ...

评论区精华

MTP 权重命名替换逻辑可能错误 正确性

检查点参数名如 mtp.model.layers.0... 会因替换 mtp 为 model 变成 model.model.layers... 无法匹配

结论:建议用 name.startswith('mtp.') 并直接去除前缀更安全 · 未解决

绕过 super().__init__ 导致基类属性未初始化 设计

直接调用 nn.Module.__init__ 而不是 super().__init__ 会跳过 attn_cp_size 等属性,且修改 config.num_hidden_layers 可能副作用

结论:建议复制 config 并正确调用 super · 未解决

风险与影响

  1. 权重加载错误:上文中提到的 repleace('mtp', 'model') 可能将 mtp.model.layers 错误替换为 model.model.layers,导致权重无法正确加载,草稿模型在训练或推理时出现异常。

  2. 基类属性缺失:跳过 super().__init__ 导致 attn_cp_sizeattn_cp_rank 等上下文并行相关属性未初始化,若启用 context parallel 可能触发 AttributeError

  3. 共享 config 副作用:在 __init__ 中原地修改 config.num_hidden_layers = 1,如果同一 config 对象被其他实例共享,可能导致意外行为。

  4. 缺少测试覆盖:当前没有对应的单元测试或集成测试验证 MTP 草稿模型的正确性,回归风险较高。

用户影响:Qwen3-MoE 模型用户现在可以启用 MTP 风格推测解码,获得一定的推理加速。但必须使用正确的权重文件并确保架构映射生效。
系统影响:仅当指定 Qwen3MoeForCausalLMMTP 架构时才会触发新路径,不影响现有模型逻辑。
团队影响:新增一个模型类,后续需维护与父类的同步;两个 review 指出的潜在问题可能在未来引发 bug,需密切跟踪。

权重重命名逻辑可能错误 super.__init__ 绕过可能导致属性缺失 缺少测试覆盖 config 对象共享修改可能副作用

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论