Prhub

#1788 [WIP] fix loss oom

THUDM/slime · 作者 lilei199908 · 合并时间 2026-04-04 23:41

分析状态 已生成
文件变更 2提交数 2 · 评论 4
代码增减 +241 / -62
bugfix performance configuration

执行摘要

修复损失计算内存溢出,优化 PPO 熵计算和 Megatron 损失路径。

PR body中展示了优化前后的内存使用对比图像,表明在损失计算过程中存在内存峰值问题。动机是修复OOM,确保训练过程更加稳定,特别是在使用大模型或复杂配置时。

建议工程师精读此PR,特别是熵梯度控制设计和checkpointing优化,这些是内存优化中的常见技巧。同时关注Copilot指出的潜在正确性问题,以确保变更不影响训练稳定性。

讨论亮点

Review中Copilot指出三个关键问题:1) 温度缩放缺失,导致rollout_temperature != 1.0时PPO行为不一致;2) 熵梯度处理在allgather_cp路径下可能无效,因为_allgather_cp_redistribute使用可微分操作;3) allgather_cp配置仅支持thd格式,但代码未做检查。Zhuzilin询问代码移动和logits.clone()的必要性。讨论未显示明确结论,但PR已获批准,可能问题被接受或后续处理。

实现拆解

主要改动在两个文件:1) slime/backends/megatron_utils/loss.py:重构了_allgather_cp_redistribute函数,引入_build_shifted_tokens函数以优化token构建逻辑,并调整get_log_probs_and_entropy函数一次性计算完整logits的log-probs和熵,减少重复计算。2) slime/utils/ppo_utils.py:修改calculate_log_probs_and_entropy函数,添加need_entropy_grad参数,当熵系数为零时使用torch.no_grad()避免梯度跟踪,降低内存开销;同时调整代码顺序和checkpointing为use_reentrant=False

文件 模块 状态 重要度
slime/backends/megatron_utils/loss.py Megatron 损失模块 modified 7.0
slime/utils/ppo_utils.py PPO 工具模块 modified 6.0

分析完成后,这里会展示 LLM 生成的相对完整源码片段和详细注释。

关键符号

_allgather_cp_redistribute _build_shifted_tokens get_log_probs_and_entropy calculate_log_probs_and_entropy

评论区精华

温度缩放缺失 正确性

Copilot 指出 `get_log_probs_and_entropy` 未应用 `rollout_temperature` 缩放,当 `rollout_temperature != 1.0` 时,可能导致 PPO ratios 和 KL 计算不一致,影响训练正确性。

结论:未明确解决,PR 已批准,可能问题被接受或忽略。 · unresolved

熵梯度处理可能无效 性能

Copilot 提到在 `allgather_cp` 启用时,`_allgather_cp_redistribute` 使用可微分操作,可能导致熵梯度优化失效,内存减少效果有限。

结论:未明确解决。 · unresolved

allgather_cp 配置检查 设计

Copilot 建议添加断言确保 `allgather_cp` 仅与 `thd` 格式配合使用,避免配置错误导致的运行时问题。

结论:未明确解决。 · unresolved

代码移动和 clone 必要性 设计

Zhuzilin 询问在 `calculate_log_probs_and_entropy` 中移动 log_prob 计算和 `logits.clone()` 是否必要,涉及内存和正确性权衡。

结论:未明确回答,但 PR 已批准,可能变更被接受。 · unresolved

风险与影响

风险包括:1) 温度缩放缺失可能改变PPO损失计算,影响训练收敛;2) 熵梯度优化在allgather_cp启用时可能不生效,内存减少有限;3) 配置不一致(如allgather_cpqkv_format不匹配)可能导致运行时错误;4) 核心路径变更引入回归风险,需测试验证正确性。

对用户:减少训练时内存使用,降低OOM风险,提升大模型训练体验。对系统:修改核心损失计算路径,影响所有使用Megatron和PPO的训练任务;性能优化可能提升整体训练效率。影响范围中高,需在集成后监控内存和收敛行为。

温度缩放缺失 熵梯度处理可能无效 配置不一致风险

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

执行摘要

  • 一句话:修复损失计算内存溢出,优化PPO熵计算和Megatron损失路径。
  • 推荐动作:建议工程师精读此PR,特别是熵梯度控制设计和checkpointing优化,这些是内存优化中的常见技巧。同时关注Copilot指出的潜在正确性问题,以确保变更不影响训练稳定性。

功能与动机

PR body中展示了优化前后的内存使用对比图像,表明在损失计算过程中存在内存峰值问题。动机是修复OOM,确保训练过程更加稳定,特别是在使用大模型或复杂配置时。

实现拆解

主要改动在两个文件:1) slime/backends/megatron_utils/loss.py:重构了_allgather_cp_redistribute函数,引入_build_shifted_tokens函数以优化token构建逻辑,并调整get_log_probs_and_entropy函数一次性计算完整logits的log-probs和熵,减少重复计算。2) slime/utils/ppo_utils.py:修改calculate_log_probs_and_entropy函数,添加need_entropy_grad参数,当熵系数为零时使用torch.no_grad()避免梯度跟踪,降低内存开销;同时调整代码顺序和checkpointing为use_reentrant=False

关键文件:

  • slime/backends/megatron_utils/loss.py(模块 Megatron损失模块): 包含损失计算的核心逻辑重构,优化内存使用和allgather操作。
  • slime/utils/ppo_utils.py(模块 PPO工具模块): 修改PPO log-probs和熵计算函数,添加熵梯度控制以减少内存开销。

关键符号:_allgather_cp_redistribute, _build_shifted_tokens, get_log_probs_and_entropy, calculate_log_probs_and_entropy

评论区精华

Review中Copilot指出三个关键问题:1) 温度缩放缺失,导致rollout_temperature != 1.0时PPO行为不一致;2) 熵梯度处理在allgather_cp路径下可能无效,因为_allgather_cp_redistribute使用可微分操作;3) allgather_cp配置仅支持thd格式,但代码未做检查。Zhuzilin询问代码移动和logits.clone()的必要性。讨论未显示明确结论,但PR已获批准,可能问题被接受或后续处理。

  • 温度缩放缺失 (correctness): 未明确解决,PR已批准,可能问题被接受或忽略。
  • 熵梯度处理可能无效 (performance): 未明确解决。
  • allgather_cp配置检查 (design): 未明确解决。
  • 代码移动和clone必要性 (design): 未明确回答,但PR已批准,可能变更被接受。

风险与影响

  • 风险:风险包括:1) 温度缩放缺失可能改变PPO损失计算,影响训练收敛;2) 熵梯度优化在allgather_cp启用时可能不生效,内存减少有限;3) 配置不一致(如allgather_cpqkv_format不匹配)可能导致运行时错误;4) 核心路径变更引入回归风险,需测试验证正确性。
  • 影响:对用户:减少训练时内存使用,降低OOM风险,提升大模型训练体验。对系统:修改核心损失计算路径,影响所有使用Megatron和PPO的训练任务;性能优化可能提升整体训练效率。影响范围中高,需在集成后监控内存和收敛行为。
  • 风险标记:温度缩放缺失, 熵梯度处理可能无效, 配置不一致风险

关联脉络

  • PR #1775 [Fix] Fix duplicate Megatron LR scheduler resume when optimizer state is not loaded: 同样涉及Megatron模块的性能bugfix,主题相关。
  • PR #1764 Add host memory metrics to available_memory function: 与内存监控和优化相关,主题相似。
  • PR #1769 Support FP8 conversion for Qwen3.5: 性能优化相关,都涉及训练效率改进。

参与讨论