Prhub

#1822 Revert no_grad for entropy to prevent comm stuck in dsa

THUDM/slime · 作者 zhuzilin · 合并时间 2026-04-09 19:20

分析状态 已生成
文件变更 2提交数 1 · 评论 0
代码增减 +20 / -19
bugfix configuration performance

执行摘要

移除熵计算中的 no_grad 上下文,修复 DSA 模式下通信卡死问题。

根据PR标题和提交信息,变更动机是修复DSA(Distributed Shared Architecture)模式下通信卡死问题。具体表现为熵计算使用torch.no_grad()时,在分布式环境中可能导致通信操作无法正常完成。PR body未提供详细描述,但从代码变更可推断需要确保熵计算张量具有梯度信息以维持分布式通信的连续性。

建议技术管理者和核心工程师精读此PR,重点关注:

  1. 熵计算梯度保留的设计决策,理解DSA通信机制的特殊要求。
  2. 分布式张量重建逻辑中对None值的处理方式,确保边缘场景覆盖。
  3. 结合近期PR #1788(修复loss oom)和 #1762(修复grad_norm初始化)一起分析,这些PR都涉及损失计算和梯度处理的底层优化。
讨论亮点

该PR没有review评论,属于直接合并的修复。从代码变更看,核心决策是彻底移除熵计算中的梯度控制逻辑,统一使用可计算梯度的张量,这可能是基于DSA环境下通信机制的特殊要求。

实现拆解

实现方案分为两个关键文件修改:

  1. slime/backends/megatron_utils/loss.py:修改_allgather_cp_redistribute函数,增加对None值的跳过逻辑,并统一使用参考张量的dtype/device创建零张量,避免因value为None导致的属性访问错误。同时移除need_entropy_grad参数及相关逻辑。
  2. slime/utils/ppo_utils.py:重构calculate_log_probs_and_entropy函数,完全移除need_entropy_grad参数和torch.no_grad()上下文管理,确保熵计算始终使用可计算梯度的logits.clone()输入,避免梯度信息丢失。
文件 模块 状态 重要度
slime/backends/megatron_utils/loss.py megatron_utils modified 7.0
slime/utils/ppo_utils.py ppo_utils modified 8.0

分析完成后,这里会展示 LLM 生成的相对完整源码片段和详细注释。

关键符号

_allgather_cp_redistribute calculate_log_probs_and_entropy compute_entropy_from_logits

评论区精华

熵计算梯度保留的必要性 正确性

PR 没有实际 review 讨论,但从代码变更可推断核心争议点在于是否应该保留熵计算的梯度信息。原实现通过 need_entropy_grad 参数和 no_grad 上下文控制梯度计算,新实现统一使用可计算梯度的张量。

结论:决定完全移除梯度控制,确保 DSA 环境下通信连续性。 · 已解决

分布式张量重建中的 None 值处理 正确性

loss.py 中增加了对全 None 值列表的跳过逻辑,并统一使用参考张量的 dtype/device 创建零张量,避免属性访问错误。

结论:实现更健壮的 None 值处理机制,确保分布式通信的鲁棒性。 · 已解决

风险与影响

技术风险包括:

  1. 性能影响:移除no_grad可能增加显存占用和计算开销,因为熵计算现在会保留梯度信息。
  2. 兼容性风险:变更可能影响非DSA环境下的训练行为,特别是当entropy_coef=0时,原本不需要梯度计算,现在可能产生不必要的开销。
  3. 逻辑一致性:loss.py中增加对None值的跳过逻辑,需确保在所有分布式场景下都能正确处理None值,避免遗漏边缘情况。
    关键风险点在于熵计算梯度保留对整体训练稳定性的影响,需验证是否会导致梯度爆炸或内存溢出。

影响范围:

  1. 对用户:修复DSA环境下的训练卡死问题,提升分布式训练的稳定性,但可能略微增加显存使用。
  2. 对系统:影响所有使用Megatron损失计算和PPO熵计算的训练流程,特别是涉及分布式通信的场景。
  3. 对团队:变更涉及核心训练逻辑,需要团队关注后续性能监控和回归测试。
    影响程度中等,主要针对特定架构(DSA)的问题修复,但改动触及分布式通信和梯度计算的基础层。
核心路径变更 梯度计算调整 分布式通信依赖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

执行摘要

  • 一句话:移除熵计算中的no_grad上下文,修复DSA模式下通信卡死问题。
  • 推荐动作:建议技术管理者和核心工程师精读此PR,重点关注:
    1. 熵计算梯度保留的设计决策,理解DSA通信机制的特殊要求。
    2. 分布式张量重建逻辑中对None值的处理方式,确保边缘场景覆盖。
    3. 结合近期PR #1788(修复loss oom)和 #1762(修复grad_norm初始化)一起分析,这些PR都涉及损失计算和梯度处理的底层优化。

功能与动机

根据PR标题和提交信息,变更动机是修复DSA(Distributed Shared Architecture)模式下通信卡死问题。具体表现为熵计算使用torch.no_grad()时,在分布式环境中可能导致通信操作无法正常完成。PR body未提供详细描述,但从代码变更可推断需要确保熵计算张量具有梯度信息以维持分布式通信的连续性。

实现拆解

实现方案分为两个关键文件修改:

  1. slime/backends/megatron_utils/loss.py:修改_allgather_cp_redistribute函数,增加对None值的跳过逻辑,并统一使用参考张量的dtype/device创建零张量,避免因value为None导致的属性访问错误。同时移除need_entropy_grad参数及相关逻辑。
  2. slime/utils/ppo_utils.py:重构calculate_log_probs_and_entropy函数,完全移除need_entropy_grad参数和torch.no_grad()上下文管理,确保熵计算始终使用可计算梯度的logits.clone()输入,避免梯度信息丢失。

关键文件:

  • slime/backends/megatron_utils/loss.py(模块 megatron_utils): 修改了分布式张量重建的核心函数_allgather_cp_redistribute,增加None值跳过逻辑并统一dtype/device引用,直接影响损失计算的通信稳定性。
  • slime/utils/ppo_utils.py(模块 ppo_utils): 重构了calculate_log_probs_and_entropy函数,彻底移除梯度控制参数和no_grad上下文,这是修复通信卡死的核心变更点。

关键符号:_allgather_cp_redistribute, calculate_log_probs_and_entropy, compute_entropy_from_logits

评论区精华

该PR没有review评论,属于直接合并的修复。从代码变更看,核心决策是彻底移除熵计算中的梯度控制逻辑,统一使用可计算梯度的张量,这可能是基于DSA环境下通信机制的特殊要求。

  • 熵计算梯度保留的必要性 (correctness): 决定完全移除梯度控制,确保DSA环境下通信连续性。
  • 分布式张量重建中的None值处理 (correctness): 实现更健壮的None值处理机制,确保分布式通信的鲁棒性。

风险与影响

  • 风险:技术风险包括:

    1. 性能影响:移除no_grad可能增加显存占用和计算开销,因为熵计算现在会保留梯度信息。
    2. 兼容性风险:变更可能影响非DSA环境下的训练行为,特别是当entropy_coef=0时,原本不需要梯度计算,现在可能产生不必要的开销。
    3. 逻辑一致性:loss.py中增加对None值的跳过逻辑,需确保在所有分布式场景下都能正确处理None值,避免遗漏边缘情况。
      关键风险点在于熵计算梯度保留对整体训练稳定性的影响,需验证是否会导致梯度爆炸或内存溢出。
  • 影响:影响范围:

    1. 对用户:修复DSA环境下的训练卡死问题,提升分布式训练的稳定性,但可能略微增加显存使用。
    2. 对系统:影响所有使用Megatron损失计算和PPO熵计算的训练流程,特别是涉及分布式通信的场景。
    3. 对团队:变更涉及核心训练逻辑,需要团队关注后续性能监控和回归测试。
      影响程度中等,主要针对特定架构(DSA)的问题修复,但改动触及分布式通信和梯度计算的基础层。
  • 风险标记:核心路径变更, 梯度计算调整, 分布式通信依赖

关联脉络

  • PR #1788 [WIP] fix loss oom: 同样修改了slime/backends/megatron_utils/loss.py文件,优化损失计算内存使用,与本PR的loss修改有直接关联。
  • PR #1762 [Fix] Initialize grad_norm before found_inf skip path: 涉及Megatron训练中的梯度处理问题修复,与本PR的梯度计算调整属于同一技术领域。
  • PR #1807 sync from internal: 同样修改了megatron_utils模块,优化多模态训练兼容性,显示该模块近期活跃度较高。

参与讨论