Prhub

#40973 [Bugfix][CPU] Backport PT cpp codegen indirect_assert scalar-mask fix

原始 PR 作者 amd-lalithnc 合并时间 2026-04-29 22:21 文件变更 1 提交数 2 评论 5 代码增减 +117 / -0

执行摘要

Backport PT fix 修复 CPU 端 torch.compile 编译错误

PR #40972 报告 Qwen3-VL 在 PyTorch 2.11 CPU 后端编译时抛出 CppCompileError,根源在于 Inductor 生成的 C++ 代码中对标量 mask 调用了不正确的构造函数 VecMask(scalar),而正确用法应为 VecMask::from(scalar)。该上游修复已在 PyTorch 主线合并(#178148),但需等到 2.12 才正式发布。此 PR 在 vLLM 侧以 monkey-patch 形式提前 Backport 该修复,并遵循既有的 env_override 模式。

值得阅读,特别是了解 vLLM 如何通过 env_override 模式紧急 backport 上游修复。延迟加载 import hook 的实现技巧也值得学习。建议在升级 PyTorch 最低版本到 2.12 前保留此补丁。

讨论亮点

版本范围讨论:gemini-code-assist[bot] 建议将版本检查中的 "2.11.0" 改为 "2.11.0.dev" 以覆盖开发版,但维护者未采纳该建议,最终合并时仍使用 "2.11.0"

多批准:tlrmchlsmth、fadara01(上游修复作者)、bigPYJ1151 分别给予批准。

实现拆解

  1. vllm/env_override.py 中新增 _apply_cpp_indirect_assert_patch 函数,该函数获取 CppVecKernel 类,将其 indirect_assert 方法替换为修复版本 patched_indirect_assert。修复关键在于当 mask 为标量时使用 VecMask::from(mask) 替代 VecMask(mask)

  2. 新增 _patch_cpp_indirect_assert_if_needed 函数,通过版本检查将补丁限制在 PyTorch >=2.11.0 且 <2.12.0.dev。为避免在 vLLM 初始化时过早导入 torch._inductor.codegen.cpp,引入自定义 MetaPathFinder 子类 _CppCodegenPatchFinder,利用 find_spec 钩子在模块首次加载时执行补丁。

  3. 在文件末尾调用 _patch_cpp_indirect_assert_if_needed() 自动注册 import hook 并触发补丁(在模块加载后)。

  4. 补丁使用标志位 _vllm_indirect_assert_patched 保证幂等性,不会重复应用。

  5. 本次变更未新增或修改测试文件,但作者提供了详细的端到端验证结果。

文件 模块 状态 重要度
vllm/env_override.py 编译补丁 modified 8.59

关键符号

patched_indirect_assert _apply_cpp_indirect_assert_patch _patch_cpp_indirect_assert_if_needed _CppCodegenPatchFinder.find_spec _CppCodegenPatchFinder._exec_then_patch

关键源码片段

vllm/env_override.py core-logic

单文件改动,新增补丁函数和延迟加载机制,是核心变更实现

def _apply_cpp_indirect_assert_patch():
    """替换 CppVecKernel.indirect_assert 以使用 VecMask::from 修复标量 mask 构造错误。"""
    from torch._inductor.codegen.cpp import CppVecKernel
​
    # 幂等性保护
    if getattr(CppVecKernel, "_vllm_indirect_assert_patched", False):
        return
​
    from torch._inductor.codegen.cpp import CppCSEVariable, cexpr_index
​
    def patched_indirect_assert(self, var, lower, upper, mask=None):
        """修复版的 indirect_assert,对标量 mask 使用 ::from。"""
        assert isinstance(var, CppCSEVariable)
        assert var.dtype is not None
​
        # 非向量情况:直接父类方法
        if not var.is_vec:
            if isinstance(mask, CppCSEVariable) and mask.is_vec:
                mask = f"({mask}).all_masked()"
            return super(CppVecKernel, self).indirect_assert(var, lower, upper, mask)
​
        lower_scalar = lower
        upper_scalar = upper
        if lower:
            lower = f"{self._get_vec_type(var.dtype)}({lower})"
        if upper:
            upper = f"{self._get_vec_type(var.dtype)}({upper})"
        if lower and upper:
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = f"{lower_scalar} <= {var}"
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = f"{var} < {upper_scalar}"
​
        cond = f"{self._get_mask_type(var.dtype)}({cond})"
        if mask:
            if not mask.is_vec:
                # 关键修复:使用 ::from 而非直接构造
                mask = f"{self._get_mask_type(var.dtype)}::from({mask})"
            cond = f"({cond}) | ~({mask})"
​
        if self.tail_size:
            cond = (
                f"{self._get_mask_type(var.dtype)}::set("
                f"{self._get_mask_type(var.dtype)}::from(1)"
                f", ({cond}), {cexpr_index(self.tail_size)})"
            )
​
        cond = f"({cond}).all_masked()"
        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
​
    CppVecKernel.indirect_assert = patched_indirect_assert
    CppVecKernel._vllm_indirect_assert_patched = True

评论区精华

版本范围检查应考虑开发版 正确性

gemini-code-assist[bot] 提出版本检查 `is_torch_equal_or_newer("2.11.0")` 会遗漏开发版(如 2.11.0.dev),建议改为 `"2.11.0.dev"` 以覆盖更多用户。

结论:PR 维护者未采纳该建议,最终合并时仍使用 `"2.11.0"`。 · not addressed

风险与影响

  • 版本兼容风险:补丁仅针对 PyTorch 2.11.x 稳定版,若用户使用 2.11.0.dev 开发版可能遗漏。当未来最小 PyTorch 版本升级到 2.12 后,需手动移除该补丁,否则可能产生冲突。
  • 延迟加载复杂性:import hook 依赖 torch._inductor.codegen.cpp 模块路径,若路径变化可能导致补丁失效。该模式已在 #38205 中初步验证,但仍有潜在风险。
  • 测试覆盖不足:本次没有新增单元测试,长期维护依赖人工回归和 CI 覆盖。
  • 影响范围:仅 CPU 后端且仅涉及使用间接索引的模型,其他路径不受影响。
  • 用户影响:修复了 Qwen3-VL 等模型在 CPU 端使用 torch.compile 时的编译错误,使模型可正常运行。其他模型无变化。
  • 系统影响:在 vLLM 初始化时引入一个版本判断和 import hook,开销极小。
  • 团队影响:需跟踪 PyTorch 版本升级,及时移除补丁。补丁模式为将来类似 backport 提供了参考。
版本兼容风险 延迟加载依赖变形 缺少测试覆盖

关联 Issue

#178148 [CPU][Inductor] Use VecMask::from for scalar masks in codegen
#38205 [ZenCPU] Make PT Backport Patch Accessible to vLLM
#40972 [Bug]: [CPU] Qwen3-VL fails at torch.compile warmup on PT 2.11 with CppCompileError: VecMask<...>::VecMask(int&)

完整报告

参与讨论