Prhub

#21388 Multi platform Plugin

原始 PR 作者 Baidu-AIAK 合并时间 2026-04-20 08:23 文件变更 22 提交数 22 评论 97 代码增减 +2811 / -17

执行摘要

引入统一插件框架和跨硬件平台抽象层

来自 PR body:硬件供应商和高级用户需要一种无需 fork 或修改主仓库即可扩展 SGLang 的方式。受 vLLM 平台抽象启发,该框架提供非侵入式、零配置的 OOT 硬件平台发现和函数级/类级的钩子注入能力。

强烈建议精读。该 PR 是 SGLang 迈向多平台生态的关键一步,其设计决策(Mixin 组合、钩子注册表、惰性平台发现)值得学习。特别关注 hook_registry.py 中钩子装饰器的实现和 platforms/__init__.py 中的发现流程,以及如何在现有代码中最小侵入地引入扩展点。文档 plugin.md 也是快速上手的良好参考。

讨论亮点
  1. 类替换统一到 HookRegistry:alexnails 建议将 class_replacer.py 合并到 HookRegistry,通过 HookType.REPLACE 统一处理。Baidu-AIAK 已实施,删除独立文件。
  2. 简化 API 为单一调用:merrymercy 要求减少调用点,Baidu-AIAK 将 load_plugins() 设计为内部自动执行 apply_hooks(),调用者只需一次函数调用。
  3. 装饰器 vs 命令式风格:alexnails 提议添加 @plugin_hook 装饰器使钩子注册更简洁,该功能已加入。
  4. 入口点命名:merrymercy 和 alexnails 认为 sglang.general_plugins 命名易混淆,最终改为 sglang.srt.plugins
  5. Mixin 而非钻石继承:alexnails 建议使用 DeviceMixin 避免 SRTPlatform 和 MMPlatform 的钻石继承,Baidu-AIAK 采纳并实现。
  6. is_out_of_tree() 分支的过渡性:多位 reviewer 希望避免 OOT 分支散落各处,Baidu-AIAK 承认这是过渡方案,并在文档中规划了未来统一的分发方向。
  7. 子进程钩子重新加载:merrymercy 指出 spawn 子进程丢失主进程状态,需在子进程入口调用 load_plugins(),已实现在 run_scheduler_process 中。
  8. 测试约定:AgainstEntropy 要求使用 CustomTestCase 替代 unittest.TestCase,已遵循。

实现拆解

  1. 核心抽象层
    • sglang/srt/platforms/device_mixin.py 中定义 DeviceMixin 基类,封装设备身份查询(is_cuda/is_npu等)和基础操作(get_device_total_memory 等)。
    • sglang/srt/platforms/interface.py 中定义 SRTPlatform 子类,添加 LLM 推理所需的工厂方法(get_default_attention_backend、get_graph_runner_cls 等)、能力标志(support_cuda_graph 等)和配置生命周期钩子。
    • OOT 平台通过多重继承 MySRTPlatform(SRTPlatform, MyDeviceMixin) 组合实现。
  2. 平台发现
    • sglang/srt/platforms/__init__.py 中实现 current_platform 惰性单例。优先使用 SGLANG_PLATFORM 环境变量选择指定插件(前端过滤,避免导入未选中的依赖);未设置时自动发现所有 entry_points 并期望恰好一个激活。
    • 注册组:sglang.srt.platforms(硬件平台)和 sglang.srt.plugins(通用功能)。
  3. 钩子系统
    • sglang/srt/plugins/hook_registry.py 中实现 HookRegistry,支持四种钩子类型:BEFORE、AFTER、AROUND、REPLACE。
    • 提供 plugin_hook 装饰器,支持同时注册函数钩子和类替换。类替换通过直接 setattr 实现,保留 isinstance/issubclass 语义。
    • 应用顺序:REPLACE 始终在围绕/前置/后置钩子之前应用,避免冲突。
  4. 加载与集成
    • load_plugins() 是核心入口,在 cli/serve.pylaunch_server.pyentrypoints/engine.py 的主进程和 managers/scheduler.py 的子进程中调用(子进程因 spawn 模式需重新加载)。
    • load_plugins() 自动在插件执行后调用 HookRegistry.apply_hooks(),调用者只需调用一次。
  5. OOT 集成点
    • 在 model_runner、memory_pool、server_args、multi_platform、compilation/backend、utils/common 中通过 if current_platform.is_out_of_tree() 分支提供 OOT 自定义入口。这些分支是过渡方案,未来将统一为纯平台分发。
  6. 测试与文档
    • 三个测试文件覆盖钩子语义、平台接口和插件加载逻辑,共 21 个测试用例。
    • 新增 docs/platforms/plugin.md,包含完整的使用指南和插件开发示例。
文件 模块 状态 重要度
python/sglang/srt/plugins/hook_registry.py 插件框架 added 9.25
python/sglang/srt/platforms/device_mixin.py 设备抽象 added 9.24
python/sglang/srt/platforms/interface.py 平台接口 added 8.96
python/sglang/srt/plugins/__init__.py 插件加载 added 8.69
python/sglang/srt/platforms/__init__.py 平台发现 added 8.64
test/registered/unit/platforms/test_platform_interface.py 平台测试 added 8.14
test/registered/unit/plugins/test_hook_registry.py 钩子系统测试 added 8.14
test/registered/unit/plugins/test_load_plugins.py 插件加载测试 added 8.01
python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py KV 缓存 modified 7.38
python/sglang/srt/model_executor/model_runner.py 模型运行器 modified 7.16
python/sglang/srt/layers/utils/multi_platform.py 多平台操作 modified 6.85
docs/platforms/plugin.md 平台文档 added 6.76

关键符号

load_plugins load_plugins_by_group _get_excluded_dists _resolve_platform _load_platform_class HookRegistry.register HookRegistry.apply_hooks HookRegistry._target_sort_key SRTPlatform.get_default_attention_backend SRTPlatform.get_graph_runner_cls SRTPlatform.get_mha_kv_pool_cls SRTPlatform.get_mla_kv_pool_cls SRTPlatform.get_nsa_kv_pool_cls SRTPlatform.get_paged_allocator_cls SRTPlatform.support_cuda_graph SRTPlatform.support_piecewise_cuda_graph DeviceMixin.is_cuda DeviceMixin.is_npu DeviceMixin.is_out_of_tree DeviceMixin.get_device_total_memory DeviceMixin.get_current_memory_usage register_oot_forward

关键源码片段

python/sglang/srt/plugins/hook_registry.py dependency-wiring

钩子注册表核心实现,定义 HookType 枚举、HookRegistry 类(注册、应用、排序),以及 `plugin_hook` 装饰器。是通用功能插件的基石。

"""
Hook registry for SGLang plugins.Provides before/after/around/replace hooks that can be applied to any
function, method, or class in the sglang codebase.
"""
import contextvars
import logging
from collections import defaultdict
from collections.abc import Callable
from enum import Enum
from typing import NamedTuplelogger = logging.getLogger(__name__)# 用于追踪当前正在加载的插件来源,确保 register() 可以自动关联
_current_plugin_source: contextvars.ContextVar[HookSource | None] = (
    contextvars.ContextVar("_current_plugin_source", default=None)
)
​
​
class HookSource(NamedTuple):
    plugin_name: str
    dist_name: str | None
​
​
class HookType(Enum):
    BEFORE = "before"
    AFTER = "after"
    AROUND = "around"
    REPLACE = "replace"
​
​
class HookRegistry:
    """全局钩子注册表。所有注册应在 load_plugins() 期间完成(单线程)。"""
​
    _hooks: dict[str, list[tuple[HookType, Callable, HookSource | None]]] = defaultdict(list)
    _patched: set[str] = set()
​
    @classmethod
    def register(
        cls,
        target: str, # 完整限定名,如 "sglang.srt.managers.scheduler.Scheduler.schedule"
        hook: Callable,
        hook_type: HookType = HookType.AFTER,
        *,
        source: HookSource | None = None,
    ):
        # 类对象只能用于 REPLACE 类型
        if isinstance(hook, type) and hook_type != HookType.REPLACE:
            raise TypeError(
                f"Class {hook.__name__} can only be used with HookType.REPLACE"
            )
        resolved_source = source or _current_plugin_source.get()
        cls._hooks[target].append((hook_type, hook, resolved_source))
​
    @classmethod
    def apply_hooks(cls):
        """应用所有已注册的钩子。按 REPLACE -> AROUND/BEFORE/AFTER 顺序,
        确保类替换先于函数包裹执行。"""
        targets = sorted(cls._hooks.keys(), key=lambda t: not any(
            ht == HookType.REPLACE for ht, _, _ in cls._hooks[t]
        ))
        for target in targets:
            if target in cls._patched:
                continue
            obj = _resolve_obj(target) # 通过 pkgutil.resolve_name 解析
            hooks = cls._hooks[target]
            # 分离 REPLACE 钩子
            replace_hooks = [(ht, h, s) for ht, h, s in hooks if ht == HookType.REPLACE]
            other_hooks = [(ht, h, s) for ht, h, s in hooks if ht != HookType.REPLACE]
            # 应用类替换
            for _, hook, source in replace_hooks:
                _apply_replace(target, hook, source)
            # 应用包裹钩子
            for hook_type, hook, source in other_hooks:
                _apply_wrap(target, obj, hook, hook_type, source)
            cls._patched.add(target)
python/sglang/srt/platforms/device_mixin.py dependency-wiring

定义 DeviceMixin 基类和 PlatformEnum,为所有平台提供统一的设备身份查询和操作接口,是 SRTPlatform 和未来 MMPlatform 的共同基础。

"""
Shared device abstraction for SGLang platforms.DeviceMixin 提供了设备身份查询和基础操作,
可由 SRTPlatform 和 MMPlatform 共同继承。
"""
import enum
from typing import TYPE_CHECKING, NamedTuple, Optionalif TYPE_CHECKING:
    import torch
​
​
class PlatformEnum(enum.Enum):
    # 已知平台列表,OOT 代表外部注册的插件
    CUDA = enum.auto()
    ROCM = enum.auto()
    CPU = enum.auto()
    XPU = enum.auto()
    MUSA = enum.auto()
    NPU = enum.auto()
    TPU = enum.auto()
    MPS = enum.auto()
    OOT = enum.auto() # 外部插件
    UNSPECIFIED = enum.auto()
​
​
class DeviceCapability(NamedTuple):
    major: int
    minor: int
​
    def as_version_str(self) -> str:
        return f"{self.major}.{self.minor}"
​
    def to_int(self) -> int:
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor
​
​
class DeviceMixin:
    """设备身份查询和基础操作混入类。子类通过 _enum 标识自己。"""
​
    _enum: PlatformEnum = PlatformEnum.UNSPECIFIED
    device_name: str = "unknown"
    device_type: str = "cpu"
​
    # ---- 身份查询 ----
    def is_cuda(self) -> bool:
        return self._enum == PlatformEnum.CUDA
​
    def is_rocm(self) -> bool:
        return self._enum == PlatformEnum.ROCM
​
    def is_npu(self) -> bool:
        return self._enum == PlatformEnum.NPU
​
    def is_cuda_alike(self) -> bool:
        # CUDA、ROCm、MUSA 被认为具有类似 CUDA 的 API
        return self._enum in (
            PlatformEnum.CUDA,
            PlatformEnum.ROCM,
            PlatformEnum.MUSA,
        )
​
    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT
​
    # ---- 设备操作(抽象,OOT 必须实现)----
    def get_device_total_memory(self, device_id: int = 0) -> int:
        raise NotImplementedError
​
    def get_current_memory_usage(self, device: Optional["torch.device"] = None) -> float:
        raise NotImplementedError

评论区精华

ClassReplacer 合并到 HookRegistry 设计

alexnails 建议将 ClassReplacer 合并到 HookRegistry,通过 HookType.REPLACE 统一处理,避免重复的注册逻辑。

结论:Baidu-AIAK 接受建议,删除了独立 class_replacer.py,将类替换作为 HookType 的一部分集成。 · 已解决

简化插件加载 API 为单一调用 设计

merrymercy 要求将 apply_worker_patches() 等分开的调用简化为单一函数调用。

结论:Baidu-AIAK 修改 load_plugins() 使其内部自动调用 HookRegistry.apply_hooks(),调用者只需调用一次。 · 已解决

入口点命名从 general_plugins 改为 plugins style

merrymercy 和 alexnails 认为 sglang.general_plugins 命名易混淆,建议缩短。

结论:改为 sglang.srt.plugins 和 sglang.srt.platforms,与包结构一致。 · 已解决

DeviceMixin 设计以避免钻石继承 设计

alexnails 指出原始单继承模式无法被多模态平台平滑复用,建议采用 Mixin 模式让 SRTPlatform 和 MMPlatform 共享 DeviceMixin。

结论:Baidu-AIAK 采纳,将平台拆分为 DeviceMixin + SRTPlatform,并预留 MMPlatform 插槽。 · 已解决

is_out_of_tree() 分支的过渡性质 设计

多位 reviewer(alexnails)担心 OOT 分支散落各处,希望最终统一通过 current_platform 分发。

结论:Baidu-AIAK 承认这是过渡方案,文档中已规划未来将内置平台也迁移到插件体系。当前保持最小侵入。 · acknowledged

子进程钩子重新加载 正确性

merrymercy 指出 spawn 子进程不继承主进程内存,需在子进程入口重新加载插件和应用钩子。

结论:在 run_scheduler_process 中调用 load_plugins() 并移除 tp_worker.py 中的冗余调用。 · 已解决

使用 CustomTestCase 替代 unittest.TestCase style

AgainstEntropy 指出需要遵循 SGLang 测试规范,使用 CustomTestCase。

结论:Baidu-AIAK 修改了测试文件以符合规范。 · 已解决

OOT 平台钩子 get_rope 不生效的 bug 报告 正确性

chenxb002 报告钩子虽然显示应用,但实际未生效。Baidu-AIAK 分析可能是钩子应用时机问题。

结论:该问题在 PR 合并后可能仍然存在,需后续跟踪修复(可能存在模块绑定过时的问题)。 · unresolved

风险与影响

  • 回归风险:OOT 分支(is_out_of_tree())在多个核心路径中新增了条件判断。如果 OOT 平台的行为不符合预期,可能影响默认的执行流程。但所有新分支仅在明确激活 OOT 平台时触发,对内置平台无影响。
  • 子进程重新加载钩子:spawn 子进程需再次调用 load_plugins(),若加载失败或顺序错误,可能导致钩子未生效。当前设计在 run_scheduler_process 中加载,但需确保异常不影响主进程。
  • 钩子应用顺序:REPLACE 钩子先于 AROUND/BEFORE/AFTER 执行,但若多个 REPLACE 作用于同一目标,仅最后一个生效(有日志警告)。OOT 开发者需注意冲突。
  • 依赖隐式加载load_plugins_by_group 会尝试加载所有发现的 entry_points,如果插件中有硬错误(如缺少硬件驱动),整个加载流程会捕获异常但不会阻止主进程运行。需确认异常处理不影响核心功能。
  • 缺少对内置平台的广泛测试:当前单元测试仅覆盖框架本身,未测试 OOT 集成点与现有功能的组合(如 CUDA 下使用 OOT 分支),后续 PR 需补充集成测试。
  • 对用户:现有用户无感知,除非他们安装并配置了 OOT 平台插件。使用 SGLANG_PLATFORMSGLANG_PLUGINS 环境变量的用户可以按需激活扩展。
  • 对系统:核心组件(ModelRunner、MemoryPool 等)新增了少量条件分支,但性能无影响(OOT 分支检查为简单布尔判断)。框架自身在未激活插件时开销极低。
  • 对团队:提供了标准的平台抽象和扩展机制,减少后续添加新硬件时的重复劳动。但需注意维护 OOT 集成点的兼容性。
  • 影响范围:涉及 22 个文件,新增约 2800 行代码,包括核心抽象、集成点、测试和文档。
核心路径变更 子进程重新加载 钩子应用顺序依赖 OOT API 不稳定 缺少回归测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论