执行摘要
本PR将数据并行工作量平衡功能从强化学习迁移到监督微调训练器,通过添加配置开关和集成负载平衡算法,优化批次序列长度分布,旨在减少管道并行中的空闲时间,提升训练效率。
功能与动机
根据PR body,动机是迁移DP workload balancing feature from RL to SFT,引用自PR #3605。目的是在SFT训练中实现类似的负载平衡,解决数据并行单元间工作负载不均的问题,以优化训练性能。关键表述来自PR body:"Migrate the DP workload balancing feature from RL to SFT"。
实现拆解
- 配置文件变更:在
verl/trainer/config/sft_trainer_engine.yaml中添加balance_batch: True配置项,默认启用负载平衡。
- 训练器逻辑变更:在
verl/trainer/sft_trainer_ray.py的fit函数中,添加以下代码块:
python
if self.config.trainer.balance_batch:
global_seqlen_lst = torch.Tensor([item.size()[0] for item in data["input_ids"]])
global_seqlen_lst = calculate_workload(global_seqlen_lst)
dp_size = max(self.training_client._query_dispatch_info("train")) + 1
global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, k_partitions=dp_size, equal_size=True)
for idx, partition in enumerate(global_partition_lst):
partition.sort(key=lambda x: (global_seqlen_lst[x], x))
ordered_partition = partition[::2] + partition[1::2][::-1]
global_partition_lst[idx] = ordered_partition
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
data = tu.index_select_tensor_dict(data, global_idx)
关键步骤包括序列长度计算、分区、排序以减少气泡,最终重排序数据张量。
评论区精华
review讨论中,gemini-code-assist[bot]指出两个关键问题:
"1. Inefficiency: Line 312 recalculates sequence lengths by iterating over data["input_ids"], which is inefficient for nested tensors. The sequence lengths are already computed and available in the batch_seqlens variable from line 305. Reusing this variable would be much more performant.
2. Bug: Line 317 passes global_seqlen_lst (a torch.Tensor) to get_seqlen_balanced_partitions, which expects a list[int]."
wuxibin89建议:
"Please reuse index_select_tensor_dict"
作者arvyanh回复已修复,问题得到解决,体现代码优化和正确性改进。
风险与影响
- 技术风险:
- 回归风险:新逻辑可能影响训练稳定性,尤其是序列长度计算和类型转换环节,如未正确处理torch.Tensor到list的转换。
- 性能风险:如果未优化重复计算(如使用batch_seqlens),可能导致训练速度下降。
- 兼容性风险:
balance_batch默认启用,可能与现有SFT配置或工作流冲突,需用户验证。
- 影响分析:
- 对用户:通过配置可启用负载平衡,潜在提升训练效率,但需测试验证效果。
- 对系统:优化数据并行工作负载,减少管道并行气泡,可能提高整体吞吐量。
- 对团队:复用RL功能,促进代码共享,但需确保SFT场景下的集成稳定性。
关联脉络
本PR直接关联历史PR #3605,迁移其DP workload balancing feature从RL到SFT,显示功能跨场景复用的趋势。从近期历史PR分析看,类似性能优化PR如#5057(动态CP)也涉及megatron和perf标签,表明仓库在持续优化训练效率。整体脉络指向在多种训练场景中集成负载平衡和并行优化技术,以提升系统性能。
参与讨论