Prhub

#40786 Fix PP in Gemma4

原始 PR 作者 SKRohit 合并时间 2026-04-29 18:17 文件变更 1 提交数 5 评论 11 代码增减 +9 / -16

执行摘要

修复 Gemma4 PP 中 residual 和 per_layer_inputs 同步

PR body 指出:"Fixes PP in Gemma4 for Gemma4ForCausalLM. per_layers_inputs needs to be passed only for smaller models and residual is None from decoder layers so need not to be synchronized among different ranks." 原始实现中不必要的张量同步可能导致错误或性能开销。

建议阅读此 PR 以了解 Gemma4 在 PP 下的张量同步设计,特别是 IntermediateTensors 如何按需传递。对于有类似 PP + PLE 实现的模型开发者,这是一个值得关注的决策案例——如何平衡泛化与模型特定优化。

讨论亮点
  • gemini-code-assist[bot] 评论:指出 forwardif per_layer_inputs is not None: 条件在非首 rank 始终为 False(因为参数默认 None),认为这会禁用 PLE 功能。但 PR 作者和 reviewer 未采纳此建议,理由是 per_layer_inputs 仅对较小模型需要。
  • 作者提供 GSM8k 验证:PP=4 时准确率 0.954,证明改动不影响正确性。
  • DarkLight1337 批准:要求运行 lm-eval 验证后确认无误。

实现拆解

  1. 移除 _make_empty_intermediate_tensors 中的 residual 张量(文件 vllm/model_executor/models/gemma4.py):在创建空的 IntermediateTensors 时不再包含 "residual" 项,因为 Gemma4 的 decoder layers 不输出独立 residual。
  2. 调整 forward 中非首 rank 分支:不再无条件从 intermediate_tensors 读取 residualper_layer_inputs,改为仅当 per_layer_inputs 参数非 None(即首 rank 传递了该值)时才从 intermediate_tensors 中读取;residual 在两者中都统一设为 None。
  3. 调整最终返回的 IntermediateTensors:在非末 rank 返回时,只打包 hidden_states 和(非 None 的)per_layer_inputs,不再包含 residual。
    以上改动确保 PP 各 rank 间只传递必要张量,符合 Gemma4 的设计预期。
文件 模块 状态 重要度
vllm/model_executor/models/gemma4.py 模型执行器 modified 6.62

关键符号

forward _make_empty_intermediate_tensors

关键源码片段

vllm/model_executor/models/gemma4.py data-contract

修复 PP 在 Gemma4 中的核心模型文件,所有 9 行新增和 16 行删除均在此文件,涉及 IntermediateTensors 构造和 forward 路径。

# vllm/model_executor/models/gemma4.py# 改动 1: _make_empty_intermediate_tensors 中移除 residual 张量
# Gemma4 decoder layers 不产生独立 residual,无需同步
def _make_empty_intermediate_tensors(
    batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
    tensors: dict[str, torch.Tensor] = {
        "hidden_states": torch.zeros(
            (batch_size, hidden_size), dtype=dtype, device=device
        ),
        # 之前存在 "residual" 项,已移除
    }
    if ple_dim and ple_dim > 0:
        tensors["per_layer_inputs"] = torch.zeros(
            (batch_size, num_layers, ple_dim), dtype=dtype, device=device
        )
    return IntermediateTensors(tensors)# 改动 2: forward 中非首 rank 分支调整
# 之前无条件读取 intermediate_tensors["residual"] 和 intermediate_tensors["per_layer_inputs"]
# 现在 per_layer_inputs 仅在显式传递时(非 None)才从 intermediate_tensors 读取
else:
    assert intermediate_tensors is not None
    hidden_states = intermediate_tensors["hidden_states"]
    # per_layer_inputs 参数在非首 rank 默认为 None,因此不读取,符合 Gemma4 设计
    if per_layer_inputs is not None:
        per_layer_inputs = intermediate_tensors["per_layer_inputs"]
residual = None # residual 总是 None# 改动 3: 返回 IntermediateTensors 时也不再包含 residual
# 只包含 hidden_states 和非 None 的 per_layer_inputs
if not get_pp_group().is_last_rank:
    tensors: dict[str, torch.Tensor] = {
        "hidden_states": hidden_states,
    }
    if per_layer_inputs is not None:
        tensors["per_layer_inputs"] = per_layer_inputs
    return IntermediateTensors(tensors)

评论区精华

per_layer_inputs 条件错误可能禁用 PLE 正确性

gemini-code-assist[bot] 指出 `if per_layer_inputs is not None:` 在非首 rank 始终为 False 会导致 PLE 功能被禁用,因为参数默认 None,无法从 intermediate_tensors 正确读取 per_layer_inputs。

结论:PR 作者和 reviewer 认为此行为符合 Gemma4 设计(per_layer_inputs 仅针对较小模型),并提供了 GSM8k 验证结果,未修改代码即合并。 · 已解决

风险与影响

风险点:改动较局部,但 PP 路径是核心逻辑。

  • 缺失测试覆盖:未增加 PP 特定测试,若未来对其他模型或 Gemma4 变种启用 PLE,则当前条件可能导致 per_layer_inputs 无法正确传递。
  • review 中未解决的疑虑:gemini-code-assist[bot] 指出的条件问题虽未被采纳,但若后续有模型需要跨 rank 传递 per_layer_inputs,当前代码存在隐患。
  • 兼容性风险:改动仅影响 Gemma4 模型,不影响其他模型。

影响范围:仅限 Gemma4 模型使用 Pipeline Parallelism 的场景。修复后 PP 能正确运行,消除可能因 residual 同步导致的错误。用户无需任何配置更改。
影响程度:中等,修复了正确性问题,但改动量小,且经过正确性验证。

缺少 PP 测试覆盖 per_layer_inputs 条件可能隐藏 bug

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论