Prhub

#40562 [Bugfix][Torch 2.12] Fix batch_invariant test with allow_override for torch 2.12 upgrade

原始 PR 作者 Lucaskabela 合并时间 2026-04-23 04:48 文件变更 1 提交数 3 评论 4 代码增减 +6 / -2

执行摘要

修复 Torch 2.12 下 bmm 注册冲突

Torch 2.12 在 https://github.com/pytorch/pytorch/pull/179082 中引入了内置 Triton aten::bmm 内核,导致 vLLM 的 batch_invariant 模块在初始化时抛出 RuntimeError: already a kernel registered from python,阻塞 Torch 2.12 升级。

值得阅读,了解 Torch 与下游框架在 dispatcher 层面的交互。

讨论亮点

gemini-code-assist[bot] 建议将同样的 allow_override=True 应用到其他操作符(如 softmax、mean 等),以防未来出现类似冲突;同时建议恢复“monkeypatch”相关的注释。但 PR 作者未采纳,其他审核人(yewentao256, zou3519)均批准。

实现拆解

  1. 定位问题:在 vllm/model_executor/layers/batch_invariant.pyenable_batch_invariant_mode() 函数中,_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA") 与 Torch 2.12 新增的内核注册冲突。
  2. 解决方案:为 _batch_invariant_LIB.impl 调用添加 allow_override=True 参数,显式允许替换已有内核。该参数自 Torch 2.8 起可用,无兼容性问题。同时更新注释说明原因。
  3. 测试验证:运行 tests/v1/determinism/test_batch_invariance.py 全部 9 个测试通过。
文件 模块 状态 重要度
vllm/model_executor/layers/batch_invariant.py 矩阵层 modified 5.8

关键符号

enable_batch_invariant_mode

关键源码片段

vllm/model_executor/layers/batch_invariant.py data-contract

核心变更文件,修改了 `enable_batch_invariant_mode()` 中 `_batch_invariant_LIB.impl` 调用,添加 `allow_override=True`。

# vllm/model_executor/layers/batch_invariant.py ( 第 966-973 行 )
# 注释:torch 2.12+ 注册了内置 Triton bmm 内核
# (torch._native.ops.bmm_outer_product),
# 因此需要使用 allow_override 来在 dispatcher 层替换它。
_batch_invariant_LIB.impl(
    "aten::bmm", bmm_batch_invariant, "CUDA", allow_override=True
)
_original_torch_bmm = torch.bmm
torch.bmm = bmm_batch_invariant # 保留直接 monkeypatch 作为回退

评论区精华

是否需要对所有 impl 调用添加 allow_override 设计

gemini-code-assist[bot] 建议为 softmax、mean 等所有 impl 调用添加 allow_override,以防 future 冲突。

结论:PR 仅修复 bmm,其他操作符未报告冲突,故未采纳。 · 已解决

风险与影响

低风险。变更仅限于 batch_invariant.py 中的一行 impl 调用,使用 allow_override=True 是 Torch 官方推荐的 API。测试已验证。

影响范围限于使用 Torch 2.12+ 且启用 batch invariance 功能的场景。修复后,vLLM 可顺利在 Torch 2.12 上初始化。

核心路径变更 缺少测试覆盖

关联 Issue

#180905 [vllm] [2.12 regression] torch.library.Library.impl("aten::bmm", ..., "CUDA") now fails with "already a kernel registered from python"

完整报告

参与讨论