执行摘要
本PR修复了FSDP训练中因动态批处理micro-batch计数不同步导致的CUDA死锁问题,通过向prepare_dynamic_batch函数传递dp_group参数确保所有数据并行rank同步,提升训练稳定性。
功能与动机
动机源于commit f5c34bb在verl/utils/seqlen_balancing.py中添加了dp_group is not None guard,导致FSDP actor调用prepare_dynamic_batch时未传递dp_group参数,跳过all_reduce操作。在动态批处理场景下,不同rank因序列长度分布差异计算不同micro-batch数量,进而引发FSDP collectives死锁。PR body指出:"不同序列长度分布导致不同rank计算不同微批次计数...FSDP collectives (AllGather/ReduceScatter)死锁"。
实现拆解
实现集中于单个文件verl/workers/actor/dp_actor.py:
- 在
compute_log_prob函数第468行附近,修改prepare_dynamic_batch调用,添加dp_group=torch.distributed.group.WORLD。
- 在
update_policy函数第560行附近,进行相同修改。
关键代码片段:
micro_batches, batch_idx_list = prepare_dynamic_batch(
data, max_token_len=max_token_len, dp_group=torch.distributed.group.WORLD
)
这确保所有rank通过all_reduce(MAX)同步micro-batch计数,防止死锁。
评论区精华
Review讨论简单,仅有gemini-code-assist[bot]的评论:
"The fix, which involves passing torch.distributed.group.WORLD as the dp_group to prepare_dynamic_batch in both compute_log_prob and update_policy, is a robust solution."
评论肯定了修复的有效性,无争议点。wuxibin89直接批准。
风险与影响
- 风险:变更引入
dp_group参数,使用WORLD在所有FSDP并行配置下正确,但需确保其他调用prepare_dynamic_batch的地方也正确处理该参数。PR提供了在2-node EKS集群的测试验证,但缺少单元测试覆盖。
- 影响:修复了FSDP动态批处理训练中的死锁bug,影响所有使用该配置的用户,提升训练可靠性和性能。测试显示修复后4/4运行成功,而死锁前100%再现。
关联脉络
本PR是#5451的FSDP对应版本,#5451修复了megatron workers中的相同bug。这表明动态批处理同步问题在多个并行策略中普遍存在,需要跨模块统一处理。近期历史PR如#5604涉及FSDP workers重构,但本PR专注于具体bugfix。
参与讨论