# PR #5689 完整报告

- 仓库：`verl-project/verl`
- 标题：[fsdp] fix: avoid NestedTensor jagged dim ambiguity for 3D position_ids
- 合并时间：2026-03-23 17:42
- 原文链接：http://prhub.com.cn/verl-project/verl/pull/5689

---

## 执行摘要
本 PR 修复了在使用 FSDP 进行 VLM SFT 训练时，因 PyTorch NestedTensor 在处理相同序列长度样本时错误选择 jagged 维度而导致的间歇性形状不匹配错误。通过将 `torch.nested.as_nested_tensor` 替换为 `torch.nested.nested_tensor_from_jagged`，并明确指定偏移量，确保了 3D position_ids 等多维张量的正确维度结构，消除了训练失败风险，提升了系统稳定性。

## 功能与动机
PR 旨在解决错误“The size of tensor a (3) must match the size of tensor b (8)”，该错误在 VLM SFT 训练使用 FSDP 且 `DatasetPadMode.NO_PADDING` 时间歇性出现。根本原因是当微批次中所有样本具有相同 seq_len 时，`torch.nested.as_nested_tensor` 会错误地将第一个非批次维度（如 num_heads=4）作为 jagged 维度，而不是 seq_len 维度，导致 `.values()` 返回错误形状（例如 (8, 1, 100) 而非正确的 (4, 1, 200)）。Issue 评论中进一步建议使用 `nested_tensor_from_jagged` 来避免此歧义。

## 实现拆解
改动集中在 `verl/utils/dataset/dataset_utils.py` 文件的 `collate_variable_batch` 函数中：

- **关键逻辑变更**：对于维度大于等于 2 的张量（如 3D position_ids），使用 `torch.nested.nested_tensor_from_jagged` 构建 NestedTensor，通过以下步骤：
 1. 计算 values：沿最后一个维度（dim=-1）连接所有张量，例如从形状 (4, 100) 和 (4, 100) 得到 (4, 200)。
 2. 计算 offsets：基于各张量的 seq_len 生成偏移量，确保正确指定 jagged_dim。
 3. 调用 `nested_tensor_from_jagged(values, offsets=offsets)`。
- **保持兼容性**：对于低维张量（dim < 2），保留原 `torch.nested.as_nested_tensor` 调用，以避免不必要开销。

commit 历史显示从最初在 FSDPEngineWithLMHead 中使用 `unbind()+cat()` 修复，演进到在 collator 源头实施更彻底的解决方案。

## 评论区精华
- **设计决策优化**：wuxibin89 在 issue 评论中指出：“We should use `nested_tensor_from_jagged` instead of `nested_tensor` to process 3d position_ids.” 并提供代码示例证明正确性。这影响了 PR 最终实现，从局部修复调整到 collator 级别。
- **修复认可**：gemini-code-assist[bot] 评论：“The pull request effectively addresses the NestedTensor jagged dimension ambiguity... The added comment clearly explains the root cause and the chosen solution, which is beneficial for future maintainability.” 强调了代码清晰度和维护价值。

## 风险与影响
- **风险**：
 - 回归风险：如果 offsets 计算错误（如长度列表处理不当），可能引入新 bug。
 - 性能影响：`nested_tensor_from_jagged` 可能有额外开销，但鉴于在数据预处理阶段且仅针对多维张量，影响可忽略。
 - 测试覆盖：PR 未添加单元测试，依赖现有 CI 和手动验证，可能遗漏其他边缘情况。
- **影响**：
 - 用户影响：解决了使用 FSDP 进行 VLM SFT 训练时的间歇性失败，提升训练可靠性。
 - 系统影响：修复了 `SFTTensorCollator` 中的潜在缺陷，确保 NestedTensor 构建正确，避免后续处理错误。

## 关联脉络
- **相关 PR**：PR #5717 也修改了 `verl/utils/dataset/dataset_utils.py` 文件，表明该模块在近期有持续调整，可能涉及其他修复或优化。
- **演进方向**：本 PR 从针对 FSDPEngine 的临时修复演进到在 dataset collator 中解决根因，反映了对问题定位的深化和模块化设计改进。