执行摘要
- 一句话:修复 Blackwell 上 resume 时因 inference_mode 导致的崩溃
- 推荐动作:值得快速合并的低风险 bugfix。对于维护者,可关注后续是否有其他涉及 inference tensor 的类似场景。
功能与动机
修复 Blackwell/B200 上模型 resume 时的崩溃。PR body 明确指出:RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed. 这是由于 warmup 时 RotaryEmbedding.cos_sin_cache 被替换为 inference tensor,后续 resume 时外部调用 _import_static_state 进行 inplace 写入而导致的问题。
实现拆解
- 定位问题文件及函数:在
python/sglang/srt/managers/scheduler_update_weights_mixin.py 的 _import_static_state 函数中,对模型 buffer 进行 inplace 写入(self_named_buffers[name][...] = tensor),当 buffer 是 inference tensor 时,若不在 torch.inference_mode() 上下文中则会报错。
- 最小化修复:将函数体内的两行代码整体缩进,包裹进入
with torch.inference_mode(): 上下文管理器。
- 不改变外部行为:该函数本就在权重恢复流程中被调用,外部调用方无需修改,确保兼容性。
关键文件:
python/sglang/srt/managers/scheduler_update_weights_mixin.py(模块 权重恢复;类别 source;类型 core-logic;符号 _import_static_state): 修复的核心文件,包含 _import_static_state 函数,通过添加 torch.inference_mode() 上下文管理器修复 Blackwell 上 resume 崩溃。
关键符号:_import_static_state
关键源码片段
python/sglang/srt/managers/scheduler_update_weights_mixin.py
修复的核心文件,包含 _import_static_state 函数,通过添加 torch.inference_mode() 上下文管理器修复 Blackwell 上 resume 崩溃。
# python/sglang/srt/managers/scheduler_update_weights_mixin.py
# ... 前略
def _export_static_state(model):
return dict(
buffers=[
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
]
)
def _import_static_state(model, static_params):
# 关键修复:将整个 buffer 写入操作包裹在 inference_mode 下
# 这样即使 buffer 是 inference tensor(在 Blackwell 上因 warmup 产生),
# inplace 写入也能正常执行,避免 RuntimeError。
with torch.inference_mode():
self_named_buffers = dict(model.named_buffers())
for name, tensor in static_params["buffers"]:
self_named_buffers[name][...] = tensor
评论区精华
无 review 评论讨论,PR 被直接批准。只有一个 reviewer(hnyls2002)给出了 APPROVED 状态,但未留下文字评论。可能因为改动简单且定位清晰。
风险与影响
- 风险:风险极低:改动仅在一个文件中添加了
with torch.inference_mode(): 上下文管理器,逻辑等价性明确。核心风险是 _import_static_state 内部是否有其他操作期望在非 inference 模式下运行——但函数内仅有 buffer 的 inplace 写入,与 inference mode 无冲突。
- 影响:影响范围:仅影响调用
_import_static_state 的 resume 流程(resume_memory_occupation),且仅对 Blackwell 系列 GPU 有效。对其他 GPU 无感知(inference mode 下同写同读,行为一致)。
- 风险标记:极低风险
关联脉络
参与讨论