Prhub

#24190 Bypass torch.cuda.use_mem_pool generator-CM in SymmetricMemoryContext

原始 PR 作者 ch-wan 合并时间 2026-05-01 16:25 文件变更 1 提交数 1 评论 2 代码增减 +11 / -11

执行摘要

绕过 torch.cuda.use_mem_pool 上下文管理器,直接调用底层 C API

原实现使用 torch.cuda.use_mem_pool 作为上下文管理器,在 CUDA graph 捕获场景下需要额外维护 exited 状态和重新获取 pool 的逻辑(不可重入)。直接调用底层 C 函数可以消除这些复杂性,并提高稳定性。PR body 未直接说明动机,但从代码变更可推断。

值得精读。该 PR 展示了如何通过绕过高层 API 来简化状态管理,是典型的性能优化和代码简化实践。但应关注 PyTorch 私有 API 的兼容性。

讨论亮点

该 PR 没有 review 评论。

实现拆解

  1. 导入底层 API:从 torch.cuda.memory 导入 _cuda_beginAllocateCurrentThreadToPool_cuda_endAllocateToPool_cuda_releasePool
  2. 重构 __init__:将 self._mem_pool_ctx 替换为 self._pool_idself._device_index,用于直接引用内存池和设备;移除 self.exited 标志。
  3. 重构 __enter__:移除 self.exited 相关逻辑和 self._mem_pool_ctx.__enter__() 调用,改为直接调用 _cuda_beginAllocateCurrentThreadToPool(self._device_index, self._pool_id)
  4. 重构 __exit__:不再调用 self._mem_pool_ctx.__exit__,而是直接调用 _cuda_endAllocateToPool_cuda_releasePool;移除 self.exited = True
  5. 保持 CUDA graph 逻辑__enter____exit__ 中与 graph capture 相关的暂停/恢复逻辑保持不变。
文件 模块 状态 重要度
python/sglang/srt/distributed/device_communicators/pynccl_allocator.py 通信层 modified 6.48

关键符号

SymmetricMemoryContext.__init__ SymmetricMemoryContext.__enter__ SymmetricMemoryContext.__exit__

关键源码片段

python/sglang/srt/distributed/device_communicators/pynccl_allocator.py dependency-wiring

唯一变更文件,重构了 SymmetricMemoryContext 的核心实现,从使用 torch.cuda.use_mem_pool 上下文管理器改为直接调用底层 C API。

# 从 torch.cuda.memory 导入底层 C API
from torch.cuda.memory import (
    CUDAPluggableAllocator,
    _cuda_beginAllocateCurrentThreadToPool,
    _cuda_endAllocateToPool,
    _cuda_releasePool,
)class SymmetricMemoryContext:
    """
    Context manager for using symmetric memory with pynccl.
    """
​
    def __init__(
        self,
        group_coordinator: GroupCoordinator,
    ):
        self.group_coordinator = group_coordinator
        # 直接存储 pool ID 和设备索引,而非使用上下文管理器
        self._pool_id = get_nccl_mem_pool().id
        self._device_index = torch.cuda.current_device()
        self.is_graph_capture = torch.cuda.is_current_stream_capturing()
        # 移除了 self._mem_pool_ctx 和 self.exited
​
    def __enter__(self):
        assert (
            self.group_coordinator.pynccl_comm is not None
        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
​
        if self.is_graph_capture:
            assert (
                _graph_pool_id is not None
            ), "graph_pool_id is not set under graph capture"
            if after_2_8_0:
                torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id)
            else:
                torch._C._cuda_endAllocateCurrentStreamToPool(
                    _cur_device, _graph_pool_id
                )
​
        # 直接调用底层 API 开始分配,而非进入上下文管理器
        _cuda_beginAllocateCurrentThreadToPool(self._device_index, self._pool_id)
​
        os.environ["SGLANG_TMP_NCCL_COMM_VALUE"] = str(
            self.group_coordinator.pynccl_comm.comm.value
        )
​
        global _active_symmetric_memory_context
        _active_symmetric_memory_context = self
​
        return self
​
    def __exit__(self, exc_type, exc_val, exc_tb):
        # 直接调用底层 API 结束分配并释放 pool
        _cuda_endAllocateToPool(self._device_index, self._pool_id)
        _cuda_releasePool(self._device_index, self._pool_id)
​
        if self.is_graph_capture:
            if after_2_8_0:
                torch._C._cuda_beginAllocateCurrentThreadToPool(
                    _cur_device, _graph_pool_id
                )
            else:
                torch._C._cuda_beginAllocateToPool(_cur_device, _graph_pool_id)
​
        global _active_symmetric_memory_context
        _active_symmetric_memory_context = None
        # 移除了 self.exited = True

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险较低。变更仅限于 SymmetricMemoryContext 内部实现,对外接口不变。但直接调用私有 API(_cuda_beginAllocateCurrentThreadToPool 等)在不同 PyTorch 版本间可能不稳定,需在 after_2_8_0 条件分支中确认兼容性。此外,_cuda_releasePool 是新增调用,需确认在非 graph 模式下释放 pool 不会影响后续分配。缺少对应测试文件。

影响范围小。仅修改 pynccl_allocator.py 一个文件,只影响使用 SymmetricMemoryContext 的代码路径(如 use_symmetric_memory)。对用户透明,降低了维护复杂度和运行时开销。

缺少测试覆盖 依赖 PyTorch 私有 API

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论