Prhub

#42129 [Inductor] Fast-path Inductor fallback for vllm::*/vllm_aiter::* custom ops

原始 PR 作者 okorzh-amd 合并时间 2026-06-04 13:03 文件变更 2 提交数 6 评论 9 代码增减 +348 / -0

执行摘要

为 vllm 自定义 op 添加 Inductor 快速回退路径,防止编译挂起

当 Inductor 遇到没有注册 lowering 的自定义 op(如 vllm::all_reduce)时,除非该 op 在 FALLBACK_ALLOW_LIST 中,Inductor 会执行慢路径日志记录,其中 operator_str 递归地字符串化每一个输入 TensorBox。对于深层 MoE/TP 计算图(如 Kimi-K2.6 在 TP=8 时),IR 溯源树可达数百层,字符串化每个 op 需要数分钟 CPU 时间,导致 torch.compile 实际上挂起。此 PR 旨在通过修补 FALLBACK_ALLOW_LIST 来跳过此慢路径。

此 PR 值得精读,尤其是对使用 torch.compile 的团队。设计模式(代理类包装全局 set)是低侵入性修补的范例。建议在后续 PyTorch 升级后测试兼容性。

讨论亮点
  • 位置建议:reviewer zou3519 建议将修补代码从 parallel_state.py 移到 env_override.py,因为那里已存在其他环境修补。作者随即采纳。
  • 测试覆盖:zou3519 询问是否可能添加测试。作者回复已添加单元测试,并补充了端到端测试结果(编译时间从 3h+ 缩减到 5-7min)。
  • 替代方案:zou3519 表示想研究是否有更好的方式告诉 Inductor 回退,但未提出具体方案。PR 最终仍采用当前代理包装方案。
  • 测试位置:zou3519 建议测试文件应放在 tests/compile 下而非根目录。作者照做。

实现拆解

  1. vllm/env_override.py 中定义 _VllmFallbackAllowList 代理类,包装原始的 OrderedSet。该代理的 __contains__ 对字符串前缀为 vllm::vllm_aiter:: 返回 True,其他操作委派给内部集合;通过 __getattr__ 转发其他属性访问,保持与 Inductor 代码的兼容。
  2. 实现 _patch_inductor_fallback_allow_list() 函数,获取 torch._inductor.lowering.FALLBACK_ALLOW_LIST 并用 _VllmFallbackAllowList 包装。如果 torch._inductor.graph 模块已加载,则同时更新其本地绑定,确保 GraphLowering.call_function 使用包装后的集合。该函数是幂等的(通过检查 _vllm_patched 标志)。
  3. 在文件末尾调用 _patch_inductor_fallback_allow_list() 自动应用修补,确保在第一次编译前生效。
  4. 配套测试文件 tests/compile/test_inductor_fallback_allow_list_patch.py 覆盖代理的成员检查、委派、迭代、__getattr__ 转发等行为,以及修补应用到 lowering 和 graph 模块的正确性和幂等性。测试结果显示所有单元测试通过,且端到端编译时间从数小时降至 5-7 分钟。
文件 模块 状态 重要度
vllm/env_override.py 环境覆盖 modified 8.33
tests/compile/test_inductor_fallback_allow_list_patch.py 编译测试 added 7.48

关键符号

_patch_inductor_fallback_allow_list _VllmFallbackAllowList.__contains__ _VllmFallbackAllowList.__init__

关键源码片段

vllm/env_override.py dependency-wiring

核心修补实现,包括 _VllmFallbackAllowList 代理类和 _patch_inductor_fallback_allow_list 函数,负责包装 Inductor 的 FALLBACK_ALLOW_LIST 以自动允许 vllm 自定义操作。

# 代理类,包装 Inductor 的 FALLBACK_ALLOW_LIST,自动允许 vllm:: 和 vllm_aiter:: 命名空间class _VllmFallbackAllowList:
    """Membership proxy that auto-allows vllm::*/vllm_aiter::* base_names."""
​
    _vllm_patched = True # 标记,用于幂等检查
​
    def __init__(self, inner):
        self._inner = inner # 原始 OrderedSet
​
    def __contains__(self, item):
        # 对字符串且以 vllm:: 或 vllm_aiter:: 开头则直接允许
        if isinstance(item, str) and item.startswith(("vllm::", "vllm_aiter::")):
            return True
        # 其他情况委派给内部集合
        return item in self._inner
​
    def add(self, item):
        self._inner.add(item)
​
    def discard(self, item):
        self._inner.discard(item)
​
    def __iter__(self):
        return iter(self._inner)
​
    def __len__(self):
        return len(self._inner)
​
    def __repr__(self):
        return f"_VllmFallbackAllowList({self._inner!r})"
​
    def __getattr__(self, name):
        # 任何其他属性访问直接委派给内部集合
        return getattr(self._inner, name)
​
​
def _patch_inductor_fallback_allow_list() -> None:
    """Wrap torch._inductor.lowering.FALLBACK_ALLOW_LIST 为 _VllmFallbackAllowList."""
    try:
        from torch._inductor import lowering as _lowering
    except ImportError:
        return
​
    base = getattr(_lowering, "FALLBACK_ALLOW_LIST", None)
    if base is None or getattr(base, "_vllm_patched", False):
        return
​
    _lowering.FALLBACK_ALLOW_LIST = _VllmFallbackAllowList(base)
​
    # 同步更新 graph 模块的本地绑定,确保 GraphLowering.call_function 使用包装后的集合
    try:
        from torch._inductor import graph as _graph
        if hasattr(_graph, "FALLBACK_ALLOW_LIST"):
            _graph.FALLBACK_ALLOW_LIST = _lowering.FALLBACK_ALLOW_LIST
    except ImportError:
        pass
tests/compile/test_inductor_fallback_allow_list_patch.py test-coverage

单元测试验证代理语义和补丁应用,包括命名空间自动允许、回退行为、幂等性等。

# 测试 _VllmFallbackAllowList 代理语义class TestVllmFallbackAllowListProxy:
    """Unit tests for the membership-proxy semantics."""
​
    def test_vllm_namespace_auto_allowed(self):
        proxy = _VllmFallbackAllowList(set())
        # vllm:: 前缀操作应始终被视为允许
        assert "vllm::all_reduce" in proxy
        assert "vllm::fused_add_rms_norm" in proxy
​
    def test_vllm_aiter_namespace_auto_allowed(self):
        proxy = _VllmFallbackAllowList(set())
        # vllm_aiter:: 前缀操作也应自动允许
        assert "vllm_aiter::fused_add_rms_norm" in proxy
​
    def test_standard_entries_preserved(self):
        base = {"torchvision::roi_align", "aten::index_add"}
        proxy = _VllmFallbackAllowList(base)
        # 非 vllm 命名空间仍基于底层集合检查
        assert "torchvision::roi_align" in proxy
        assert "aten::index_add" in proxy
        assert "aten::__not_present__" not in proxy
​
    def test_add_and_discard_delegate_to_inner(self):
        inner: set[str] = set()
        proxy = _VllmFallbackAllowList(inner)
        proxy.add("custom::op")
        assert "custom::op" in inner # 操作影响内部集合
        proxy.discard("custom::op")
        assert "custom::op" not in inner

评论区精华

将补丁从 parallel_state.py 迁移到 env_override.py 设计

zou3519 建议将代码移到 env_override.py,因为那里已存在其他环境修补。作者同意并移动。

结论:已迁移。 · 已解决

是否可以为更改添加测试? 测试

zou3519 询问是否可能添加测试。作者回复已添加单元测试,并补充了端到端测试结果(编译时间从 3h+ 缩减到 5-7min)。

结论:测试已添加。 · 已解决

考虑是否有更好的方式告诉 Inductor 回退 设计

zou3519 表示想研究是否有更好的方法,但未提出具体方案。PR 最终采用当前代理包装方案。

结论:未发现更好的方法,PR 采用当前方案。 · 已解决

测试文件应放在 tests/compile 下 测试

zou3519 建议测试文件从 tests/ 根目录移到 tests/compile 下。作者照做。

结论:已移动。 · 已解决

风险与影响

该修补依赖 PyTorch 内部数据结构 torch._inductor.lowering.FALLBACK_ALLOW_LIST,若 PyTorch 未来更改此结构或引入新机制,则修补可能需要更新。修补是幂等的,但如果其他模块在修补前已导入 FALLBACK_ALLOW_LIST 并缓存了引用,则可能跳过修补(当前已处理 graph 模块的重新绑定)。另外,修补屏蔽了 vllm 操作的慢路径日志,若未来依赖该日志进行调试,可能会丢失信息。但总体风险较低,测试已覆盖关键场景。

对用户:修复了使用 torch.compile 时特定模型(尤其是大型 MoE 如 Kimi-K2.6)的编译挂起问题,使编译能在可接受时间内完成。对系统:无运行时性能影响,仅编译路径优化。对团队:维护成本低,代码集中在 env_override.py,并有完整测试覆盖。

依赖 PyTorch 内部 API 修改全局状态 可能屏蔽调试日志

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论