Prhub

#42250 [Bugfix][Model] Gemma4 MoE routing closure captures per_expert_scale, breaking functional_call substitution

原始 PR 作者 NoeliaBentancor 合并时间 2026-05-14 01:43 文件变更 1 提交数 6 评论 6 代码增减 +7 / -4

执行摘要

修复 Gemma4 MoE 路由闭包捕获参数问题

Issue #42239 报告:Gemma4MoE.init 中的 routing_function 闭包捕获了 per_expert_scale Parameter 的本地引用,导致 torch.func.functional_call 无法替换该参数,影响依赖 functional_call 的模块替换场景(如测试或模型优化)。

值得精读,特别是了解 Python 闭包捕获与 PyTorch functional API 交互的常见陷阱。该 PR 展示了如何通过避免变量捕获来确保参数替换正确工作。

讨论亮点
  • yewentao256:认为改动较小,无需添加专用单元测试。作者同意并移除了测试文件中的新增测试(最终合并仅包含源码改动)。
  • Copilot:检测到测试文件中存在未使用的变量 orig_ids,可能触发 lint 错误。作者已修复。
    • Gemini Code AssistCopilot 的自动 review 均确认了修改的正确性。

实现拆解

  1. 移除本地捕获:删除 Gemma4MoE.__init__ 中的 per_expert_scale = self.per_expert_scale 行。
  2. 运行时读取:将闭包 routing_function 内部所有对 per_expert_scale 的引用替换为 self.per_expert_scale
  3. 添加注释:在闭包定义前增加说明,解释为什么不能捕获本地变量,方便后续维护。
    仅涉及文件 vllm/model_executor/models/gemma4.py,总差异 +7 / -4 行。
文件 模块 状态 重要度
vllm/model_executor/models/gemma4.py 模型层 modified 5.98

关键符号

Gemma4MoE.__init__ routing_function

关键源码片段

vllm/model_executor/models/gemma4.py bugfix

核心修复文件,修改了 Gemma4MoE 类的 __init__ 方法,调整了 routing 闭包的参数捕获方式。

class Gemma4MoE(nn.Module):
    def __init__(self, config, quant_config=None, prefix=""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts
        # 每个专家的输出缩放因子,融合到路由权重中
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
​
        # 注意 : 此处直接在闭包内通过 self 访问 per_expert_scale,
        # 而不是先捕获到本地变量。这样做是为了确保
        # torch.func.functional_call 的参数替换能生效。
        def routing_function(
            hidden_states, gating_output, topk, renormalize
        ):
            if current_platform.is_cuda_alike() or current_platform.is_xpu():
                return gemma4_fused_routing_kernel_triton(
                    gating_output, topk, self.per_expert_scale
                )
            return gemma4_routing_function_torch(
                gating_output, topk, self.per_expert_scale
            )
​
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.top_k_experts,
            hidden_size=config.hidden_size,
            intermediate_size=getattr(config, "moe_intermediate_size", None),
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            custom_routing_function=routing_function,
            activation="gelu_tanh",
        )

评论区精华

是否需要新增单元测试 测试

yewentao256 在 review 中建议改动较小无需新增测试;作者同意并移除了测试文件的改动。

结论:移除了测试,仅保留源码修改。 · 已解决

未使用变量 orig_ids style

Copilot 指出 orig_ids 赋值但未使用,可能触发 ruff lint 错误。

结论:作者已修复该问题。 · 已解决

风险与影响

风险极低。改动本质上将闭包捕获的变量引用改为从 self 属性读取,逻辑等价。唯一的假设是 self.per_expert_scale 在闭包调用时不会变化(与原始捕获语义一致)。被移除的测试原本用于验证 functional_call 替换,回归风险较小,但 reviewer 认为改动足够简单直接。

影响范围小,仅影响 Gemma4 MoE 模块中使用 torch.func.functional_call 的场景。普通推理路径不受任何影响。

闭包捕获行为改变 无新增测试覆盖

关联 Issue

#42239 [Bug]: Gemma4 MoE routing closure captures `per_expert_scale` Parameter, breaking `torch.func.functional_call` substitution

完整报告

参与讨论