Prhub

#24096 Introduce CudaDeviceMixin and CudaSRTPlatform

原始 PR 作者 alexnails 合并时间 2026-05-16 01:59 文件变更 10 提交数 19 评论 12 代码增减 +318 / -45

执行摘要

引入 CudaDeviceMixin 与 CudaSRTPlatform 平台抽象层

SRT 运行时长期以来直接调用 torch.cuda 接口,导致平台紧密耦合 CUDA,难以扩展至 AMD ROCm 或其他硬件。该 PR 旨在构建平台抽象骨架,使设备操作多态化,并为后续 OOT(Out-of-Tree)插件机制奠定基础。Issue 评论中 alexnails 提到,此 PR 遵循与各硬件团队评审过的设计文档,但与其他提出的设计(如 #20372)略有不同。

值得精读。该 PR 是 SGLang 平台抽象层的关键基础设施,设计模式(Mixin + 自动发现、ROCm 继承 CUDA)对多硬件支持有借鉴意义。关注设备操作接口定义与回退逻辑,为后续 OOT 插件扩展提供参考。

讨论亮点

在 Review 中,AgainstEntropy 提出了多个关键意见:

  • CudaSRTPlatform 位置:建议将 CudaSRTPlatforminterface.py 移至 cuda.py,因为 interface.py 应只包含接口。该建议被 alexnails 接受并实施。
  • 方法覆盖完整性:指出 get_device_uuidseed_everything 在初始实现中未覆盖。alexnails 回应 get_device_uuid 因其他平台 UUID 格式不同而无法泛化,seed_everything 随后被移入 Mixin 中。
  • 能力标志覆盖:要求 CudaSRTPlatform 覆盖 supports_fp8support_cuda_graphsupport_piecewise_cuda_graph。最终实现中这些方法均返回 True
    • 此外,在提交历史中,作者曾回退 current_platform.empty_cache() 迁移,待 RocmSRTPlatform 就绪后才重新应用,以防止 AMD 路径静默无操作。

实现拆解

  1. 定义 CUDA 设备操作类:在 python/sglang/srt/platforms/cuda.py 中新增 CudaDeviceMixin 类,继承 DeviceMixin,实现所有与 CUDA 设备相关的操作方法,如 get_device_total_memoryget_current_memory_usageget_deviceset_deviceempty_cachesynchronize 等。同时定义 CudaSRTPlatform,继承 CudaDeviceMixinSRTPlatform,并覆盖 supports_fp8support_cuda_graph 等能力标志。

  2. ROCm 适配类:在 python/sglang/srt/platforms/rocm.py 中新增 RocmDeviceMixin,直接继承 CudaDeviceMixin,仅覆盖 _enumdevice_name,因为 PyTorch 中 HIP 设备仍然通过 torch.cuda API 暴露。RocmSRTPlatform 继承自 RocmDeviceMixinSRTPlatform,但保留默认的保守能力标志。

  3. 增强平台发现机制:修改 python/sglang/srt/platforms/__init__.py,添加 _is_cuda_available()_is_rocm_available() 辅助函数,分别检测纯 CUDA(torch.cuda.is_available()torch.version.hip is None)和 ROCm(torch.cuda.is_available()torch.version.hip is not None)。在 _resolve_platform() 回退路径中,当无插件激活时按顺序检测 CUDA/ROCm,并返回对应的 SRTPlatform 实例。

  4. 迁移 torch.cuda 调用:在多个模块中将硬编码的 torch.cuda.empty_cache()torch.cuda.synchronize() 替换为 current_platform.empty_cache()current_platform.synchronize()。涉及文件包括 loader.py(模型权重加载与卸载)、memory_pool.py(内存池管理)、scheduler.py(调度器)、model_runner.py(模型运行器)。此迁移在提交历史中曾回退后又重新应用,以确保 AMD 平台不会因缺少 RocmSRTPlatform 而静默无操作。

  5. 单元测试覆盖:在 test/registered/unit/platforms/test_platform_interface.py 中新增 TestCudaDeviceMixin 类,使用 mock 验证 CudaSRTPlatform 各方法委托到正确的 torch.cuda 函数。同时添加基础平台身份测试 test_base_device_identity_stays_unspecified,验证 SRTPlatform 基类不声称任何具体设备。

文件 模块 状态 重要度
python/sglang/srt/platforms/cuda.py 平台层 added 8.67
python/sglang/srt/platforms/rocm.py ROCm 适配 added 7.72
test/registered/unit/platforms/test_platform_interface.py 测试 modified 7.37
python/sglang/srt/platforms/__init__.py 平台发现 modified 7.16
python/sglang/srt/platforms/device_mixin.py 基类 modified 6.02
python/sglang/srt/model_loader/loader.py 模型加载 modified 5.98

关键符号

CudaDeviceMixin.get_device_total_memory CudaDeviceMixin.get_current_memory_usage CudaDeviceMixin.get_device CudaDeviceMixin.set_device CudaDeviceMixin.get_device_name CudaDeviceMixin.get_device_uuid CudaDeviceMixin.get_device_capability CudaDeviceMixin.empty_cache CudaDeviceMixin.synchronize CudaDeviceMixin.get_available_memory CudaDeviceMixin.seed_everything RocmDeviceMixin RocmSRTPlatform _resolve_platform _is_cuda_available _is_rocm_available

关键源码片段

python/sglang/srt/platforms/cuda.py core-logic

核心新增文件,定义了 CudaDeviceMixin 和 CudaSRTPlatform,是平台抽象层的基础。

class CudaDeviceMixin(DeviceMixin):
    '''CUDA implementation of the shared device operations.'''
    _enum = PlatformEnum.CUDA
    device_name = 'cuda'
    device_type = 'cuda'
​
    def get_device_total_memory(self, device_id=0):
        # 获取指定设备的总显存
        return int(torch.cuda.get_device_properties(device_id).total_memory)
​
    def get_current_memory_usage(self, device=None):
        # 获取当前显存使用量(峰值分配)
        return float(torch.cuda.max_memory_allocated(device))
​
    def get_device(self, local_rank):
        # 根据 local_rank 构造 CUDA 设备对象
        return torch.device('cuda', local_rank)
​
    def set_device(self, device):
        torch.cuda.set_device(device)
​
    def get_device_name(self, device_id=0):
        return str(torch.cuda.get_device_name(device_id))
​
    def get_device_uuid(self, device_id=0):
        # UUID 格式因平台而异,此处保持 CUDA 特定实现
        return str(torch.cuda.get_device_properties(device_id).uuid)
​
    def get_device_capability(self, device_id=0):
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major, minor)
​
    def empty_cache(self):
        torch.cuda.empty_cache()
​
    def synchronize(self):
        torch.cuda.synchronize()
​
    def get_available_memory(self, device_id=0):
        return torch.cuda.mem_get_info(device_id)
​
    def get_torch_distributed_backend_str(self):
        return 'nccl'
​
    @classmethod
    def seed_everything(cls, seed=None):
        if seed is not None:
            super().seed_everything(seed)
            torch.cuda.manual_seed_all(seed)
​
​
class CudaSRTPlatform(CudaDeviceMixin, SRTPlatform):
    '''Default in-tree CUDA SRT platform.'''
    def supports_fp8(self) -> bool:
        return True
​
    def support_cuda_graph(self) -> bool:
        return True
​
    def support_piecewise_cuda_graph(self) -> bool:
        return True
python/sglang/srt/platforms/rocm.py core-logic

新增 ROCm 适配类,继承 CUDA 设备操作实现,展示了平台抽象层的设计模式。

'''ROCm device operations for the SRT platform layer.PyTorch exposes ROCm through the same torch.cuda.* API surface as CUDA
(HIP is a binary shim, and torch.device('rocm') does not exist). So
RocmDeviceMixin inherits all device ops from CudaDeviceMixin and
only overrides identity (_enum, device_name).
'''from sglang.srt.platforms.cuda import CudaDeviceMixin
from sglang.srt.platforms.device_mixin import PlatformEnum
from sglang.srt.platforms.interface import SRTPlatform
​
​
class RocmDeviceMixin(CudaDeviceMixin):
    '''ROCm device ops — identical surface to CUDA via torch.cuda's HIP shim.'''
​
    _enum: PlatformEnum = PlatformEnum.ROCM
    device_name: str = 'rocm'
    # device_type stays 'cuda' — torch.device('cuda') is the only valid
    # device-type string for HIP devices in PyTorch.
​
​
class RocmSRTPlatform(RocmDeviceMixin, SRTPlatform):
    '''Default in-tree ROCm SRT platform.    Capability flags (supports_fp8, support_cuda_graph, support_piecewise_cuda_graph)
    keep the conservative SRTPlatform defaults rather than mirroring CudaSRTPlatform.
    They are currently only consulted in OOT branches gated on is_out_of_tree(),
    so the defaults are behaviorally inert for the in-tree ROCm path. A follow-up
    that migrates AMD-specific gating off legacy is_hip() should set these here.
    '''

评论区精华

CudaSRTPlatform 定义位置 设计

AgainstEntropy 建议将 CudaSRTPlatform 从 interface.py 移到 cuda.py,因为 interface.py 应只包含接口。

结论:alexnails 接受并移动,最终 CudaSRTPlatform 定义在 cuda.py。 · 已解决

get_device_uuid 和 seed_everything 覆盖 正确性

AgainstEntropy 指出 get_device_uuid、seed_everything 等方法在 CudaDeviceMixin 中未覆盖。alexnails 回应 get_device_uuid 不能泛化到其他平台(UUID 格式不同),seed_everything 将移入。

结论:seed_everything 被实现;get_device_uuid 保留在 CUDA 专用类中,不在 DeviceMixin 基类强制。 · 已解决

能力标志覆盖(supports_fp8 等) 设计

AgainstEntropy 要求 CudaSRTPlatform 覆盖 supports_fp8、support_cuda_graph、support_piecewise_cuda_graph。

结论:alexnails 实现,在 cuda.py 中 CudaSRTPlatform 返回 True。 · 已解决

风险与影响

平台误判风险:自动检测依赖 torch.cuda.is_available()torch.version.hip,但某些环境可能同时存在 CUDA 和 ROCm 驱动(罕见),或通过模拟设备导致误判。建议保留 SGLANG_PLATFORM 环境变量作为覆盖。

迁移完整性风险:虽然迁移了 empty_cachesynchronize,但其他 torch.cuda 函数可能未被替换(例如直接使用 torch.cuda.current_device()),存在残留硬编码。

OOT 平台兼容性:当前 in-tree 的 CudaSRTPlatformRocmSRTPlatform 与 OOT 插件共存逻辑可能引入优先级问题,若 OOT 插件与 in-tree 平台同时激活,len(activated)>1 会报错,要求设置 SGLANG_PLATFORM

用户影响:无功能性变化,向后兼容。用户无需变更配置,但未来可通过 SGLANG_PLATFORM 环境变量选择平台。

系统影响:架构上解耦了设备操作,便于添加新硬件支持。但需确保所有 CUDA 调用点被枚举替换。

团队影响:需要维护设备操作接口,新增平台只需实现 DeviceMixin 子类和 SRTPlatform 子类。

核心路径变更 平台兼容性 迁移完整性 自动检测假设

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论