执行摘要
- 一句话:引入统一插件框架和跨硬件平台抽象层
- 推荐动作:强烈建议精读。该 PR 是 SGLang 迈向多平台生态的关键一步,其设计决策(Mixin 组合、钩子注册表、惰性平台发现)值得学习。特别关注
hook_registry.py 中钩子装饰器的实现和 platforms/__init__.py 中的发现流程,以及如何在现有代码中最小侵入地引入扩展点。文档 plugin.md 也是快速上手的良好参考。
功能与动机
来自 PR body:硬件供应商和高级用户需要一种无需 fork 或修改主仓库即可扩展 SGLang 的方式。受 vLLM 平台抽象启发,该框架提供非侵入式、零配置的 OOT 硬件平台发现和函数级/类级的钩子注入能力。
实现拆解
- 核心抽象层:
- 在
sglang/srt/platforms/device_mixin.py 中定义 DeviceMixin 基类,封装设备身份查询(is_cuda/is_npu等)和基础操作(get_device_total_memory 等)。
- 在
sglang/srt/platforms/interface.py 中定义 SRTPlatform 子类,添加 LLM 推理所需的工厂方法(get_default_attention_backend、get_graph_runner_cls 等)、能力标志(support_cuda_graph 等)和配置生命周期钩子。
- OOT 平台通过多重继承 MySRTPlatform(SRTPlatform, MyDeviceMixin) 组合实现。
- 平台发现:
- 在
sglang/srt/platforms/__init__.py 中实现 current_platform 惰性单例。优先使用 SGLANG_PLATFORM 环境变量选择指定插件(前端过滤,避免导入未选中的依赖);未设置时自动发现所有 entry_points 并期望恰好一个激活。
- 注册组:
sglang.srt.platforms(硬件平台)和 sglang.srt.plugins(通用功能)。
- 钩子系统:
- 在
sglang/srt/plugins/hook_registry.py 中实现 HookRegistry,支持四种钩子类型:BEFORE、AFTER、AROUND、REPLACE。
- 提供
plugin_hook 装饰器,支持同时注册函数钩子和类替换。类替换通过直接 setattr 实现,保留 isinstance/issubclass 语义。
- 应用顺序:REPLACE 始终在围绕/前置/后置钩子之前应用,避免冲突。
- 加载与集成:
load_plugins() 是核心入口,在 cli/serve.py、launch_server.py、entrypoints/engine.py 的主进程和 managers/scheduler.py 的子进程中调用(子进程因 spawn 模式需重新加载)。
load_plugins() 自动在插件执行后调用 HookRegistry.apply_hooks(),调用者只需调用一次。
- OOT 集成点:
- 在 model_runner、memory_pool、server_args、multi_platform、compilation/backend、utils/common 中通过
if current_platform.is_out_of_tree() 分支提供 OOT 自定义入口。这些分支是过渡方案,未来将统一为纯平台分发。
- 测试与文档:
- 三个测试文件覆盖钩子语义、平台接口和插件加载逻辑,共 21 个测试用例。
- 新增
docs/platforms/plugin.md,包含完整的使用指南和插件开发示例。
关键文件:
python/sglang/srt/plugins/hook_registry.py(模块 插件框架;类别 source;类型 dependency-wiring;符号 my_timer, HookSource, _format_source, HookType): 钩子注册表核心实现,定义 HookType 枚举、HookRegistry 类(注册、应用、排序),以及 plugin_hook 装饰器。是通用功能插件的基石。
python/sglang/srt/platforms/device_mixin.py(模块 设备抽象;类别 source;类型 dependency-wiring;符号 PlatformEnum, CpuArchEnum, DeviceCapability, as_version_str): 定义 DeviceMixin 基类和 PlatformEnum,为所有平台提供统一的设备身份查询和操作接口,是 SRTPlatform 和未来 MMPlatform 的共同基础。
python/sglang/srt/platforms/interface.py(模块 平台接口;类别 source;类型 dependency-wiring;符号 SRTPlatform, apply_server_args_defaults, get_default_attention_backend, get_graph_runner_cls): 定义 SRTPlatform,继承 DeviceMixin 并添加 LLM 推理所需工厂方法、能力标志和配置生命周期钩子,是 OOT 平台插件的主要基类。
python/sglang/srt/plugins/__init__.py(模块 插件加载;类别 source;类型 dependency-wiring;符号 load_plugins_by_group, _get_excluded_dists, load_plugins): 插件发现和加载的核心入口,实现 load_plugins()、load_plugins_by_group() 和 _get_excluded_dists(),处理 setuptools entry_points 发现、白名单过滤和幂等性。
python/sglang/srt/platforms/__init__.py(模块 平台发现;类别 source;类型 dependency-wiring;符号 _resolve_platform, _load_platform_class, getattr): current_platform 惰性单例的实现,处理平台插件的前端过滤和自动发现流程,是框架使用的主要入口点。
test/registered/unit/platforms/test_platform_interface.py(模块 平台测试;类别 test;类型 test-coverage;符号 _make_device_mixin, M, get_device_total_memory, get_current_memory_usage): 平台接口单元测试,覆盖 DeviceMixin、SRTPlatform 的所有方法、PlatformEnum 枚举、平台发现和惰性初始化逻辑。
test/registered/unit/plugins/test_hook_registry.py(模块 钩子系统测试;类别 test;类型 test-coverage;符号 _make_module, _cleanup_synth_modules, _HookTestCase, setUp): 钩子注册表单元测试,覆盖 AROUND/BEFORE/AFTER/REPLACE 四种钩子类型,以及类方法/静态方法描述符保留、钩子排序、冲突检测和幂等性。
test/registered/unit/plugins/test_load_plugins.py(模块 插件加载测试;类别 test;类型 test-coverage;符号 _make_ep, _reset_plugins_loaded, TestLoadPlugins, setUp): 插件加载系统测试,验证 load_plugins 的幂等性、异常容错、SGLANG_PLUGINS 白名单过滤和 SGLANG_PLATFORM 排除逻辑。
python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py(模块 KV缓存;类别 source;类型 data-contract): OOT 分支修改最多的地方之一,基于 current_platform 为 OOT 平台选择不同的 KV cache 池类型(MHA/MLA/NSA)。
python/sglang/srt/model_executor/model_runner.py(模块 模型运行器;类别 source;类型 data-contract): 模型运行器集成 OOT 平台初始化、图捕获和权重加载逻辑,新增多个 current_platform 分支。
python/sglang/srt/layers/utils/multi_platform.py(模块 多平台操作;类别 source;类型 core-logic;符号 register_oot_forward): OOT 平台的多平台操作转发支持,通过 register_oot_forward 注册自定义 forward 方法。
docs/platforms/plugin.md(模块 平台文档;类别 docs;类型 documentation;符号 activate, MyDeviceMixin, set_device, get_device_name): 详细的技术文档,包含概述、插件类型、开发指南(创建 OOT 平台、挂钩函数、替换类)、架构图、环境变量说明和未来方向。
关键符号:load_plugins, load_plugins_by_group, _get_excluded_dists, _resolve_platform, _load_platform_class, HookRegistry.register, HookRegistry.apply_hooks, HookRegistry._target_sort_key, SRTPlatform.get_default_attention_backend, SRTPlatform.get_graph_runner_cls, SRTPlatform.get_mha_kv_pool_cls, SRTPlatform.get_mla_kv_pool_cls, SRTPlatform.get_nsa_kv_pool_cls, SRTPlatform.get_paged_allocator_cls, SRTPlatform.support_cuda_graph, SRTPlatform.support_piecewise_cuda_graph, DeviceMixin.is_cuda, DeviceMixin.is_npu, DeviceMixin.is_out_of_tree, DeviceMixin.get_device_total_memory, DeviceMixin.get_current_memory_usage, register_oot_forward
关键源码片段
python/sglang/srt/plugins/hook_registry.py
钩子注册表核心实现,定义 HookType 枚举、HookRegistry 类(注册、应用、排序),以及 plugin_hook 装饰器。是通用功能插件的基石。
"""
Hook registry for SGLang plugins.
Provides before/after/around/replace hooks that can be applied to any
function, method, or class in the sglang codebase.
"""
import contextvars
import logging
from collections import defaultdict
from collections.abc import Callable
from enum import Enum
from typing import NamedTuple
logger = logging.getLogger(__name__)
# 用于追踪当前正在加载的插件来源,确保 register() 可以自动关联
_current_plugin_source: contextvars.ContextVar[HookSource | None] = (
contextvars.ContextVar("_current_plugin_source", default=None)
)
class HookSource(NamedTuple):
plugin_name: str
dist_name: str | None
class HookType(Enum):
BEFORE = "before"
AFTER = "after"
AROUND = "around"
REPLACE = "replace"
class HookRegistry:
"""全局钩子注册表。所有注册应在 load_plugins() 期间完成(单线程)。"""
_hooks: dict[str, list[tuple[HookType, Callable, HookSource | None]]] = defaultdict(list)
_patched: set[str] = set()
@classmethod
def register(
cls,
target: str, # 完整限定名,如 "sglang.srt.managers.scheduler.Scheduler.schedule"
hook: Callable,
hook_type: HookType = HookType.AFTER,
*,
source: HookSource | None = None,
):
# 类对象只能用于 REPLACE 类型
if isinstance(hook, type) and hook_type != HookType.REPLACE:
raise TypeError(
f"Class {hook.__name__} can only be used with HookType.REPLACE"
)
resolved_source = source or _current_plugin_source.get()
cls._hooks[target].append((hook_type, hook, resolved_source))
@classmethod
def apply_hooks(cls):
"""应用所有已注册的钩子。按 REPLACE -> AROUND/BEFORE/AFTER 顺序,
确保类替换先于函数包裹执行。"""
targets = sorted(cls._hooks.keys(), key=lambda t: not any(
ht == HookType.REPLACE for ht, _, _ in cls._hooks[t]
))
for target in targets:
if target in cls._patched:
continue
obj = _resolve_obj(target) # 通过 pkgutil.resolve_name 解析
hooks = cls._hooks[target]
# 分离 REPLACE 钩子
replace_hooks = [(ht, h, s) for ht, h, s in hooks if ht == HookType.REPLACE]
other_hooks = [(ht, h, s) for ht, h, s in hooks if ht != HookType.REPLACE]
# 应用类替换
for _, hook, source in replace_hooks:
_apply_replace(target, hook, source)
# 应用包裹钩子
for hook_type, hook, source in other_hooks:
_apply_wrap(target, obj, hook, hook_type, source)
cls._patched.add(target)
python/sglang/srt/platforms/device_mixin.py
定义 DeviceMixin 基类和 PlatformEnum,为所有平台提供统一的设备身份查询和操作接口,是 SRTPlatform 和未来 MMPlatform 的共同基础。
"""
Shared device abstraction for SGLang platforms.
DeviceMixin 提供了设备身份查询和基础操作,
可由 SRTPlatform 和 MMPlatform 共同继承。
"""
import enum
from typing import TYPE_CHECKING, NamedTuple, Optional
if TYPE_CHECKING:
import torch
class PlatformEnum(enum.Enum):
# 已知平台列表,OOT 代表外部注册的插件
CUDA = enum.auto()
ROCM = enum.auto()
CPU = enum.auto()
XPU = enum.auto()
MUSA = enum.auto()
NPU = enum.auto()
TPU = enum.auto()
MPS = enum.auto()
OOT = enum.auto() # 外部插件
UNSPECIFIED = enum.auto()
class DeviceCapability(NamedTuple):
major: int
minor: int
def as_version_str(self) -> str:
return f"{self.major}.{self.minor}"
def to_int(self) -> int:
assert 0 <= self.minor < 10
return self.major * 10 + self.minor
class DeviceMixin:
"""设备身份查询和基础操作混入类。子类通过 _enum 标识自己。"""
_enum: PlatformEnum = PlatformEnum.UNSPECIFIED
device_name: str = "unknown"
device_type: str = "cpu"
# ---- 身份查询 ----
def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
def is_rocm(self) -> bool:
return self._enum == PlatformEnum.ROCM
def is_npu(self) -> bool:
return self._enum == PlatformEnum.NPU
def is_cuda_alike(self) -> bool:
# CUDA、ROCm、MUSA 被认为具有类似 CUDA 的 API
return self._enum in (
PlatformEnum.CUDA,
PlatformEnum.ROCM,
PlatformEnum.MUSA,
)
def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT
# ---- 设备操作(抽象,OOT 必须实现)----
def get_device_total_memory(self, device_id: int = 0) -> int:
raise NotImplementedError
def get_current_memory_usage(self, device: Optional["torch.device"] = None) -> float:
raise NotImplementedError
评论区精华
- 类替换统一到 HookRegistry:alexnails 建议将
class_replacer.py 合并到 HookRegistry,通过 HookType.REPLACE 统一处理。Baidu-AIAK 已实施,删除独立文件。
- 简化 API 为单一调用:merrymercy 要求减少调用点,Baidu-AIAK 将
load_plugins() 设计为内部自动执行 apply_hooks(),调用者只需一次函数调用。
- 装饰器 vs 命令式风格:alexnails 提议添加
@plugin_hook 装饰器使钩子注册更简洁,该功能已加入。
- 入口点命名:merrymercy 和 alexnails 认为
sglang.general_plugins 命名易混淆,最终改为 sglang.srt.plugins。
- Mixin 而非钻石继承:alexnails 建议使用 DeviceMixin 避免 SRTPlatform 和 MMPlatform 的钻石继承,Baidu-AIAK 采纳并实现。
- is_out_of_tree() 分支的过渡性:多位 reviewer 希望避免 OOT 分支散落各处,Baidu-AIAK 承认这是过渡方案,并在文档中规划了未来统一的分发方向。
- 子进程钩子重新加载:merrymercy 指出 spawn 子进程丢失主进程状态,需在子进程入口调用
load_plugins(),已实现在 run_scheduler_process 中。
- 测试约定:AgainstEntropy 要求使用
CustomTestCase 替代 unittest.TestCase,已遵循。
- ClassReplacer 合并到 HookRegistry (design): Baidu-AIAK 接受建议,删除了独立 class_replacer.py,将类替换作为 HookType 的一部分集成。
- 简化插件加载 API 为单一调用 (design): Baidu-AIAK 修改 load_plugins() 使其内部自动调用 HookRegistry.apply_hooks(),调用者只需调用一次。
- 入口点命名从 general_plugins 改为 plugins (style): 改为 sglang.srt.plugins 和 sglang.srt.platforms,与包结构一致。
- DeviceMixin 设计以避免钻石继承 (design): Baidu-AIAK 采纳,将平台拆分为 DeviceMixin + SRTPlatform,并预留 MMPlatform 插槽。
- is_out_of_tree() 分支的过渡性质 (design): Baidu-AIAK 承认这是过渡方案,文档中已规划未来将内置平台也迁移到插件体系。当前保持最小侵入。
- 子进程钩子重新加载 (correctness): 在 run_scheduler_process 中调用 load_plugins() 并移除 tp_worker.py 中的冗余调用。
- 使用 CustomTestCase 替代 unittest.TestCase (style): Baidu-AIAK 修改了测试文件以符合规范。
- OOT 平台钩子 get_rope 不生效的 bug 报告 (correctness): 该问题在 PR 合并后可能仍然存在,需后续跟踪修复(可能存在模块绑定过时的问题)。
风险与影响
- 风险:
- 回归风险:OOT 分支(
is_out_of_tree())在多个核心路径中新增了条件判断。如果 OOT 平台的行为不符合预期,可能影响默认的执行流程。但所有新分支仅在明确激活 OOT 平台时触发,对内置平台无影响。
- 子进程重新加载钩子:spawn 子进程需再次调用
load_plugins(),若加载失败或顺序错误,可能导致钩子未生效。当前设计在 run_scheduler_process 中加载,但需确保异常不影响主进程。
- 钩子应用顺序:REPLACE 钩子先于 AROUND/BEFORE/AFTER 执行,但若多个 REPLACE 作用于同一目标,仅最后一个生效(有日志警告)。OOT 开发者需注意冲突。
- 依赖隐式加载:
load_plugins_by_group 会尝试加载所有发现的 entry_points,如果插件中有硬错误(如缺少硬件驱动),整个加载流程会捕获异常但不会阻止主进程运行。需确认异常处理不影响核心功能。
- 缺少对内置平台的广泛测试:当前单元测试仅覆盖框架本身,未测试 OOT 集成点与现有功能的组合(如 CUDA 下使用 OOT 分支),后续 PR 需补充集成测试。
- 影响:
- 对用户:现有用户无感知,除非他们安装并配置了 OOT 平台插件。使用
SGLANG_PLATFORM 或 SGLANG_PLUGINS 环境变量的用户可以按需激活扩展。
- 对系统:核心组件(ModelRunner、MemoryPool 等)新增了少量条件分支,但性能无影响(OOT 分支检查为简单布尔判断)。框架自身在未激活插件时开销极低。
- 对团队:提供了标准的平台抽象和扩展机制,减少后续添加新硬件时的重复劳动。但需注意维护 OOT 集成点的兼容性。
- 影响范围:涉及 22 个文件,新增约 2800 行代码,包括核心抽象、集成点、测试和文档。
- 风险标记:核心路径变更, 子进程重新加载, 钩子应用顺序依赖, OOT API 不稳定, 缺少回归测试
关联脉络
参与讨论