Executive Summary
Fixes an accuracy bug on the XPU platform where all_reduce returns all zeros under torch.compile.
The PR body states: 'XPU all_reduce returns all-zeros in compile mode, dist.all_reduce is an in-place operation. When traced by inductor, the original input tensor may be optimized away since the compiler does not see a new tensor being produced, causing the output to be all-zeros.' As a result, model inference accuracy collapses entirely under torch.compile (in the test, for example, gsm8k exact_match drops from 0.52 to 0.0).
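This PR is worth a close read because it exposes a subtle accuracy problem that torch.compile can introduce when optimizing in-place operations, and it demonstrates a practical workaround: making the operation out-of-place so that the compiler sees a new tensor being produced. Points to watch: the design of the conditional clone (torch.compiler.is_compiling()), and how the added type hint improves code robustness.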
Main discussion points in the review:
- Correctness of the fix: gemini-code-assist[bot] noted 'The all_reduce implementation is now out-of-place, which correctly addresses the torch.compile issue where in-place mutations on inputs can lead to incorrect optimizations (like returning all-zeros).' This confirms that the fix is effective.
- Code style consistency: gemini-code-assist[bot] suggested 'for consistency with other methods in this class (e.g., reduce_scatter at line 54) and the base class DeviceCommunicatorBase, the method signature should ideally include a type hint for the input_ parameter.' The author, chaojun-zhang, then updated the code to add the type hint.
- Performance impact: jikunshang wrote in the merge comment 'perf impact is very limited. LGTM.', accepting the performance cost as negligible.