执行摘要
本PR修复了在使用FSDP进行VLM SFT训练时,因PyTorch NestedTensor在处理相同序列长度样本时错误选择jagged维度而导致的间歇性形状不匹配错误。通过将torch.nested.as_nested_tensor替换为torch.nested.nested_tensor_from_jagged,并明确指定偏移量,确保了3D position_ids等多维张量的正确维度结构,消除了训练失败风险,提升了系统稳定性。
功能与动机
PR旨在解决错误“The size of tensor a (3) must match the size of tensor b (8)”,该错误在VLM SFT训练使用FSDP且DatasetPadMode.NO_PADDING时间歇性出现。根本原因是当微批次中所有样本具有相同seq_len时,torch.nested.as_nested_tensor会错误地将第一个非批次维度(如num_heads=4)作为jagged维度,而不是seq_len维度,导致.values()返回错误形状(例如(8, 1, 100)而非正确的(4, 1, 200))。Issue评论中进一步建议使用nested_tensor_from_jagged来避免此歧义。
实现拆解
改动集中在verl/utils/dataset/dataset_utils.py文件的collate_variable_batch函数中:
- 关键逻辑变更:对于维度大于等于2的张量(如3D position_ids),使用
torch.nested.nested_tensor_from_jagged构建NestedTensor,通过以下步骤:
- 计算values:沿最后一个维度(dim=-1)连接所有张量,例如从形状(4, 100)和(4, 100)得到(4, 200)。
- 计算offsets:基于各张量的seq_len生成偏移量,确保正确指定jagged_dim。
- 调用
nested_tensor_from_jagged(values, offsets=offsets)。
- 保持兼容性:对于低维张量(dim < 2),保留原
torch.nested.as_nested_tensor调用,以避免不必要开销。
commit历史显示从最初在FSDPEngineWithLMHead中使用unbind()+cat()修复,演进到在collator源头实施更彻底的解决方案。
评论区精华
- 设计决策优化:wuxibin89在issue评论中指出:“We should use
nested_tensor_from_jagged instead of nested_tensor to process 3d position_ids.” 并提供代码示例证明正确性。这影响了PR最终实现,从局部修复调整到collator级别。
- 修复认可:gemini-code-assist[bot]评论:“The pull request effectively addresses the NestedTensor jagged dimension ambiguity... The added comment clearly explains the root cause and the chosen solution, which is beneficial for future maintainability.” 强调了代码清晰度和维护价值。
风险与影响
- 风险:
- 回归风险:如果offsets计算错误(如长度列表处理不当),可能引入新bug。
- 性能影响:
nested_tensor_from_jagged可能有额外开销,但鉴于在数据预处理阶段且仅针对多维张量,影响可忽略。
- 测试覆盖:PR未添加单元测试,依赖现有CI和手动验证,可能遗漏其他边缘情况。
- 影响:
- 用户影响:解决了使用FSDP进行VLM SFT训练时的间歇性失败,提升训练可靠性。
- 系统影响:修复了
SFTTensorCollator中的潜在缺陷,确保NestedTensor构建正确,避免后续处理错误。
关联脉络
- 相关PR:PR #5717也修改了
verl/utils/dataset/dataset_utils.py文件,表明该模块在近期有持续调整,可能涉及其他修复或优化。
- 演进方向:本PR从针对FSDPEngine的临时修复演进到在dataset collator中解决根因,反映了对问题定位的深化和模块化设计改进。
参与讨论