执行摘要
优化采样器在批大小变化时的重新编译行为,提升推理性能。
根据 PR body,在测试 RL 用例时,batch size = 1 会导致 sampler 重新编译,因为 'pytorch with dynamic defaults to specializing on 0/1 values',从而增加运行时开销。目标是消除此开销以提升性能。
建议工程师精读此 PR,特别是 mark_unbacked 的使用和动态形状处理策略,对于优化 PyTorch 编译性能有参考价值。关注讨论中的未决建议,如未来集成 min/max 参数,并考虑在其他编译函数中应用类似技巧。
Review 中核心讨论包括:gemini-code-assist[bot] 建议在 batched_count_greater_than 中添加 x 和 values batch 维度相等的检查以优化编译图符号统一,但未被采纳;同一 bot 还建议将 mark_unbacked 调用移到 gather_logprobs 开头以避免早期专门化,但作者 Lucaskabela 反驳称 'inductor can't codegen with a fully unbacked batch dim',因此保持原位置;laithsakka 建议使用 mark_unbacked 的 min/max 参数(适用于 torch >=2.12)作为未来改进,添加 TODO 或条件。最终 PR 被批准,部分建议留作未解决。
参与讨论