Prhub

#25616 Move weight-update RPC handlers to SchedulerWeightUpdaterManager

原始 PR 作者 fzyzcjy 合并时间 2026-05-18 18:34 文件变更 3 提交数 1 评论 3 代码增减 +225 / -280

执行摘要

将权重更新 mixin 迁移至独立组件

继续重构 scheduler,将庞大的 mixin 层逐块抽取到独立组件中,使权重更新功能内聚于 SchedulerWeightUpdaterManager,减少 Scheduler 类的多重继承负担。

此 PR 是重构系列的关键一步,值得深入阅读以理解如何将 mixin 模式过渡到组合模式。review 中提到的三点代码质量问题应在后续 PR 中修复;读者亦可学习在机械迁移中如何保证方法体不变。

讨论亮点

在 review 中 gemini-code-assist[bot] 指出了三个问题:

  • unused Tuple import(medium):Tuple 在修正返回类型后不再使用。
  • 不正确的返回类型提示:update_weights_from_distributed 声明返回 Tuple[bool,str] 但实际返回 UpdateWeightsFromDistributedReqOutput。
  • resume_memory_occupation 中使用 remove 而非 discard 从 offload_tags 集合中删除,可能导致 KeyError。
    这些意见均未被作者回复或解决,PR 已合并。

实现拆解

  1. 在 SchedulerWeightUpdaterManager 中添加新方法:将 SchedulerUpdateWeightsMixin 中的 @staticmethod 方法(如 update_weights_from_disk)复制到 weight_updater.py 的类中,去除 @staticmethod 并将签名从 self: "SchedulerWeightUpdaterManager" 简化为 self。同时复制模块级辅助函数 _export_static_state / _import_static_state 到 weight_updater.py 的模块作用域。

  2. 添加必要的导入:将 logging、traceback、torch、Tuple 以及 io_struct 中的请求/响应类型加入 weight_updater.py。

  3. 删除原文件:删除 scheduler_update_weights_mixin.py。

  4. 更新 scheduler.py:移除对 SchedulerUpdateWeightsMixin 的导入和继承,在请求分发器中将 lambda req: self.<method>(self.weight_updater, req) 替换为直接的方法引用 self.weight_updater.<method>(如 self.weight_updater.update_weights_from_disk)。

  5. 配套调整:update save_remote_modelsave_sharded_model 的调用方式(原通过类名调用,现通过实例调用)。

涉及文件:scheduler_update_weights_mixin.py(删除)、weight_updater.py(修改)、scheduler.py(修改)。

文件 模块 状态 重要度
python/sglang/srt/managers/scheduler_update_weights_mixin.py 权重更新 removed 9.28
python/sglang/srt/managers/scheduler_components/weight_updater.py 权重更新 modified 8.88
python/sglang/srt/managers/scheduler.py 调度器 modified 6.39

关键符号

flush_cache_after_weight_update update_weights_from_disk init_weights_update_group destroy_weights_update_group update_weights_from_distributed update_weights_from_tensor update_weights_from_ipc get_weights_by_name release_memory_occupation resume_memory_occupation check_weights save_remote_model save_sharded_model _export_static_state _import_static_state

关键源码片段

python/sglang/srt/managers/scheduler_update_weights_mixin.py deletion

这是被删除的源文件,包含旧有的 SchedulerUpdateWeightsMixin。该 Mixin 中的所有静态方法都被迁移至 SchedulerWeightUpdaterManager,因此该文件不再需要。其删除是整个 PR 的核心。

# file: scheduler_update_weights_mixin.py ( 已删除 )
# 原始 mixin 类,所有方法均为 @staticmethod,
# 通过 self 参数接收 SchedulerWeightUpdaterManager 实例。import logging
import traceback
from typing import TYPE_CHECKING, Tupleimport torch
from sglang.srt.constants import (
    GPU_MEMORY_ALL_TYPES,
    GPU_MEMORY_TYPE_CUDA_GRAPH,
    GPU_MEMORY_TYPE_KV_CACHE,
    GPU_MEMORY_TYPE_WEIGHTS,
)
from sglang.srt.managers.io_struct import (
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    DestroyWeightsUpdateGroupReqInput,
    DestroyWeightsUpdateGroupReqOutput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromIPCReqInput,
    UpdateWeightsFromIPCReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    CheckWeightsReqInput,
    CheckWeightsReqOutput,
)logger = logging.getLogger(__name__)class SchedulerUpdateWeightsMixin:
    @staticmethod
    def flush_cache_after_weight_update(
        self: "SchedulerWeightUpdaterManager", recv_req
    ) -> None:
        if recv_req.flush_cache:
            flush_cache_success = self.flush_cache(
                empty_cache=recv_req.torch_empty_cache
            )
            assert flush_cache_success, "Cache flush failed after updating weights"
​
    @staticmethod
    def update_weights_from_disk(
        self: "SchedulerWeightUpdaterManager", recv_req: UpdateWeightFromDiskReqInput
    ):
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        tp_success = success
        if success and self.draft_worker is not None:
            success, message = self.draft_worker.update_weights_from_disk(recv_req)
        if tp_success:
            SchedulerUpdateWeightsMixin.flush_cache_after_weight_update(self, recv_req)
        if not success:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)
​
    # 其他方法类似
python/sglang/srt/managers/scheduler_components/weight_updater.py dependency-wiring

这是迁移目标文件,SchedulerWeightUpdaterManager 类获得了所有权重更新的实例方法,并补充了必要的导入。此类替代了原来的 mixin,成为权重更新功能的唯一入口。

# file: weight_updater.py ( 迁移后 )
# SchedulerWeightUpdaterManager 现在拥有所有权重更新方法,
# 它们从 @staticmethod 转换为普通实例方法。import logging
import traceback
from typing import Any, Callable, Tuple # Tuple 实际上在类型修正后不再使用import torch
from sglang.srt.managers.io_struct import (
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    DestroyWeightsUpdateGroupReqInput,
    DestroyWeightsUpdateGroupReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
    UpdateWeightsFromIPCReqInput,
    UpdateWeightsFromIPCReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    CheckWeightsReqInput,
    CheckWeightsReqOutput,
)logger = logging.getLogger(__name__)@dataclass(kw_only=True, slots=True)
class SchedulerWeightUpdaterManager:
    # 原有字段
    tp_worker: Any
    draft_worker: Any
    tp_cpu_group: Any
    memory_saver_adapter: Any
    flush_cache: Callable[..., bool]
    is_fully_idle: Callable[..., bool]
    offload_tags: set = field(default_factory=set)
    stashed_model_static_state: Any = None
​
    def flush_cache_after_weight_update(self, recv_req) -> None:
        if recv_req.flush_cache:
            flush_cache_success = self.flush_cache(
                empty_cache=recv_req.torch_empty_cache
            )
            assert flush_cache_success, "Cache flush failed after updating weights"
​
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        tp_success = success
        if success and self.draft_worker is not None:
            success, message = self.draft_worker.update_weights_from_disk(recv_req)
        if tp_success:
            self.flush_cache_after_weight_update(recv_req)
        if not success:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)
​
    # 注意:update_weights_from_distributed 的返回类型标注为 Tuple[bool, str] 与实际返回的 UpdateWeightsFromDistributedReqOutput 不符(review 指出)。
    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ) -> Tuple[bool, str]: # 应改为 UpdateWeightsFromDistributedReqOutput
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            self.flush_cache_after_weight_update(recv_req)
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

评论区精华

未使用的 Tuple 导入 style

gemini-code-assist[bot] 指出在修正返回类型后,Tuple 导入将不再使用,建议移除。

结论:建议未被作者采纳或回复,导入仍保留。 · 待处理

update_weights_from_distributed 返回类型标注错误 正确性

方法声明返回 Tuple[bool, str] 但实际返回 UpdateWeightsFromDistributedReqOutput,gemini-code-assist[bot] 指出应修改为正确的类型。

结论:未被修复,PR 已合并。 · 待处理

resume_memory_occupation 中集合删除建议使用 discard 正确性

使用 self.offload_tags.remove(tag) 若 tag 不存在会抛出 KeyError,gemini-code-assist[bot] 建议改为 discard 以安全删除。

结论:未被采纳,PR 已合并。 · 待处理

风险与影响

  1. 类型标注错误:update_weights_from_distributed 返回类型标注为 Tuple[bool, str] 与实际不符,可能误导调用方。
  2. 潜在的 KeyError:resume_memory_occupationoffload_tags 调用 remove,若标签不存在会抛异常,建议改为 discard
  3. 测试尚未覆盖:本次迁移未添加新测试,存在回归风险。
  4. 但方法体逐行一致(PR body 声明的 byte-identical),因此引入新 bug 的概率较低。

用户:无直接影响。
系统:调度器继承链缩短,组件职责更清晰;请求分发逻辑简化(消除 lambda 包装)。
团队:降低了 scheduler.py 的复杂度,后续维护和测试更容易。
影响范围:中等,涉及调度器核心的权重更新通路,但代码等价,回归可能性较小。

类型标注错误 集合删除风险 测试缺失

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论