执行摘要
- 一句话:将权重更新 mixin 迁移至独立组件
- 推荐动作:此 PR 是重构系列的关键一步,值得深入阅读以理解如何将 mixin 模式过渡到组合模式。review 中提到的三点代码质量问题应在后续 PR 中修复;读者亦可学习在机械迁移中如何保证方法体不变。
功能与动机
继续重构 scheduler,将庞大的 mixin 层逐块抽取到独立组件中,使权重更新功能内聚于 SchedulerWeightUpdaterManager,减少 Scheduler 类的多重继承负担。
实现拆解
-
在 SchedulerWeightUpdaterManager 中添加新方法:将 SchedulerUpdateWeightsMixin 中的 @staticmethod 方法(如 update_weights_from_disk)复制到 weight_updater.py 的类中,去除 @staticmethod 并将签名从 self: "SchedulerWeightUpdaterManager" 简化为 self。同时复制模块级辅助函数 _export_static_state / _import_static_state 到 weight_updater.py 的模块作用域。
-
添加必要的导入:将 logging、traceback、torch、Tuple 以及 io_struct 中的请求/响应类型加入 weight_updater.py。
-
删除原文件:删除 scheduler_update_weights_mixin.py。
-
更新 scheduler.py:移除对 SchedulerUpdateWeightsMixin 的导入和继承,在请求分发器中将 lambda req: self.<method>(self.weight_updater, req) 替换为直接的方法引用 self.weight_updater.<method>(如 self.weight_updater.update_weights_from_disk)。
-
配套调整:update save_remote_model 和 save_sharded_model 的调用方式(原通过类名调用,现通过实例调用)。
涉及文件:scheduler_update_weights_mixin.py(删除)、weight_updater.py(修改)、scheduler.py(修改)。
关键文件:
python/sglang/srt/managers/scheduler_update_weights_mixin.py(模块 权重更新;类别 source;类型 deletion;符号 SchedulerUpdateWeightsMixin, flush_cache_after_weight_update, update_weights_from_disk, init_weights_update_group): 这是被删除的源文件,包含旧有的 SchedulerUpdateWeightsMixin。该 Mixin 中的所有静态方法都被迁移至 SchedulerWeightUpdaterManager,因此该文件不再需要。其删除是整个 PR 的核心。
python/sglang/srt/managers/scheduler_components/weight_updater.py(模块 权重更新;类别 source;类型 dependency-wiring;符号 SchedulerWeightUpdaterManager, flush_cache_after_weight_update, update_weights_from_disk, init_weights_update_group): 这是迁移目标文件,SchedulerWeightUpdaterManager 类获得了所有权重更新的实例方法,并补充了必要的导入。此类替代了原来的 mixin,成为权重更新功能的唯一入口。
python/sglang/srt/managers/scheduler.py(模块 调度器;类别 source;类型 dependency-wiring): 调度器主文件,更新了导入和类继承,简化了请求分发注册。
关键符号:flush_cache_after_weight_update, update_weights_from_disk, init_weights_update_group, destroy_weights_update_group, update_weights_from_distributed, update_weights_from_tensor, update_weights_from_ipc, get_weights_by_name, release_memory_occupation, resume_memory_occupation, check_weights, save_remote_model, save_sharded_model, _export_static_state, _import_static_state
关键源码片段
python/sglang/srt/managers/scheduler_update_weights_mixin.py
这是被删除的源文件,包含旧有的 SchedulerUpdateWeightsMixin。该 Mixin 中的所有静态方法都被迁移至 SchedulerWeightUpdaterManager,因此该文件不再需要。其删除是整个 PR 的核心。
# file: scheduler_update_weights_mixin.py ( 已删除 )
# 原始 mixin 类,所有方法均为 @staticmethod,
# 通过 self 参数接收 SchedulerWeightUpdaterManager 实例。
import logging
import traceback
from typing import TYPE_CHECKING, Tuple
import torch
from sglang.srt.constants import (
GPU_MEMORY_ALL_TYPES,
GPU_MEMORY_TYPE_CUDA_GRAPH,
GPU_MEMORY_TYPE_KV_CACHE,
GPU_MEMORY_TYPE_WEIGHTS,
)
from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromIPCReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
CheckWeightsReqInput,
CheckWeightsReqOutput,
)
logger = logging.getLogger(__name__)
class SchedulerUpdateWeightsMixin:
@staticmethod
def flush_cache_after_weight_update(
self: "SchedulerWeightUpdaterManager", recv_req
) -> None:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache(
empty_cache=recv_req.torch_empty_cache
)
assert flush_cache_success, "Cache flush failed after updating weights"
@staticmethod
def update_weights_from_disk(
self: "SchedulerWeightUpdaterManager", recv_req: UpdateWeightFromDiskReqInput
):
success, message = self.tp_worker.update_weights_from_disk(recv_req)
tp_success = success
if success and self.draft_worker is not None:
success, message = self.draft_worker.update_weights_from_disk(recv_req)
if tp_success:
SchedulerUpdateWeightsMixin.flush_cache_after_weight_update(self, recv_req)
if not success:
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
# 其他方法类似
python/sglang/srt/managers/scheduler_components/weight_updater.py
这是迁移目标文件,SchedulerWeightUpdaterManager 类获得了所有权重更新的实例方法,并补充了必要的导入。此类替代了原来的 mixin,成为权重更新功能的唯一入口。
# file: weight_updater.py ( 迁移后 )
# SchedulerWeightUpdaterManager 现在拥有所有权重更新方法,
# 它们从 @staticmethod 转换为普通实例方法。
import logging
import traceback
from typing import Any, Callable, Tuple # Tuple 实际上在类型修正后不再使用
import torch
from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromIPCReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
CheckWeightsReqInput,
CheckWeightsReqOutput,
)
logger = logging.getLogger(__name__)
@dataclass(kw_only=True, slots=True)
class SchedulerWeightUpdaterManager:
# 原有字段
tp_worker: Any
draft_worker: Any
tp_cpu_group: Any
memory_saver_adapter: Any
flush_cache: Callable[..., bool]
is_fully_idle: Callable[..., bool]
offload_tags: set = field(default_factory=set)
stashed_model_static_state: Any = None
def flush_cache_after_weight_update(self, recv_req) -> None:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache(
empty_cache=recv_req.torch_empty_cache
)
assert flush_cache_success, "Cache flush failed after updating weights"
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.tp_worker.update_weights_from_disk(recv_req)
tp_success = success
if success and self.draft_worker is not None:
success, message = self.draft_worker.update_weights_from_disk(recv_req)
if tp_success:
self.flush_cache_after_weight_update(recv_req)
if not success:
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
# 注意:update_weights_from_distributed 的返回类型标注为 Tuple[bool, str] 与实际返回的 UpdateWeightsFromDistributedReqOutput 不符(review 指出)。
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
) -> Tuple[bool, str]: # 应改为 UpdateWeightsFromDistributedReqOutput
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success:
self.flush_cache_after_weight_update(recv_req)
else:
logger.error(message)
return UpdateWeightsFromDistributedReqOutput(success, message)
评论区精华
在 review 中 gemini-code-assist[bot] 指出了三个问题:
风险与影响
关联脉络
- PR #25624 Move invariant checks to SchedulerInvariantChecker and retire runtime_checker mixin: 同系列重构 PR,将不变量检查从 mixin 迁移到独立组件,与本次权重更新迁移类似。
- PR #25635 Move output streaming to SchedulerOutputStreamer: 同系列重构 PR,将输出流从 mixin 迁移到独立组件,展示相同的迁移模式。
- PR #25639 Delete the now-unused is_work_request from scheduler.py: 同系列清理 PR,删除不再使用的函数,与本次删除旧 mixin 的清理目标一致。
参与讨论