Prhub

#24195 Fix flashinfer autotune oom glm51

原始 PR 作者 kpham-sgl 合并时间 2026-06-03 14:28 文件变更 2 提交数 7 评论 6 代码增减 +35 / -1

执行摘要

修复 FlashInfer 自动调优时因 lm_head 导致 OOM

FlashInfer 自动调优时,_dummy_run 会触发完整的 LogitsProcessor.forward,其中的 lm_head + tensor-parallel all-gather 会分配大量显存,导致 OOM。PR body 指出这是从 #23796 搬运过来的修复,并引用了 OOM CI 日志。

此 PR 设计简洁,修复明确,值得阅读以了解如何通过上下文管理器在特定路径跳过计算,避免 OOM。

讨论亮点

Fridge003 提出是否可以直接从 flashinfer 的 autotune 上下文中获取全局标志,避免自建 context manager。kpham-sgl 回复称 flashinfer 有一个相关的 PR(#3396)但尚未合并,目前没有替代方案,因此保留当前实现。

实现拆解

  1. logits_processor.py 中新增上下文管理器 autotune_dummy_run_mode:设置全局标志 _in_autotune_dummy_run = True,并在 LogitsProcessor.forward 开头检查该标志,若为 True 则直接返回空输出(LogitsProcessorOutput(next_token_logits=None)),从而跳过所有 lm_head 计算和 TP all-gather。

  2. model_runner.py_flashinfer_autotune 方法中启用该上下文管理器:导入 autotune_dummy_run_mode,并在 torch.inference_mode()autotune() 上下文之上嵌套使用,确保 dummy run 期间自动跳过 lm_head。

  3. 配套调整:新增 import from contextlib import contextmanager,在 _in_autotune_dummy_run 旁添加详细注释说明动机和 OOM 机制。

文件 模块 状态 重要度
python/sglang/srt/layers/logits_processor.py logits 处理器 modified 7.64
python/sglang/srt/model_executor/model_runner.py 模型运行器 modified 6.17

关键符号

get_in_autotune_dummy_run autotune_dummy_run_mode LogitsProcessor.forward ModelRunner._flashinfer_autotune

关键源码片段

python/sglang/srt/layers/logits_processor.py core-logic

核心变更:新增 `_in_autotune_dummy_run` 全局标志、`get_in_autotune_dummy_run()` 和 `autotune_dummy_run_mode()` 上下文管理器,并在 `forward()` 方法开头添加早期返回逻辑以跳过 lm_head 计算。

# logits_processor.py ( 片段 )# 当该标志为 True 时,LogitsProcessor.forward 返回空输出并跳过
# LM head + tensor-parallel all-gather。FlashInfer autotune 只
# 需要 profile attention/MoE/GEMM 内核,LM-head all-gather 是
# 多余的计算,并且在 DP attention + 紧 mem_fraction_static 下
# 其 [batch * dp_size, vocab] 输出会 OOM。
_in_autotune_dummy_run = False
​
​
def get_in_autotune_dummy_run() -> bool:
    return _in_autotune_dummy_run
​
​
@contextmanager
def autotune_dummy_run_mode():
    """上下文管理器,在 FlashInfer autotune 期间启用 dummy run 模式。"""
    global _in_autotune_dummy_run
    _in_autotune_dummy_run = True
    try:
        yield
    finally:
        _in_autotune_dummy_run = False
​
​
class LogitsProcessor(nn.Module):
    # ...
    def forward(self, ...):
        # ...
        # 检查是否处于 autotune dummy run:若是则跳过所有 LM head 计算
        if _in_autotune_dummy_run:
            return LogitsProcessorOutput(next_token_logits=None)
        # 后续正常的 logits 处理逻辑 ...
python/sglang/srt/model_executor/model_runner.py data-contract

变更点:在 `_flashinfer_autotune` 方法中导入并使用 `autotune_dummy_run_mode()` 上下文管理器,嵌套在现有 `torch.inference_mode()` 和 `autotune()` 上下文中,使得 dummy run 自动跳过 lm_head 计算。

# model_runner.py ( 片段 )def _flashinfer_autotune(self):
    """Run flashinfer autotune."""
    from flashinfer.autotuner import autotune
    from sglang.srt.layers.logits_processor import autotune_dummy_run_mode
​
    cache_path = self._flashinfer_autotune_cache_path()
    # ... 缓存逻辑 ...
​
    self.forward_stream.wait_stream(torch.cuda.current_stream())
    with torch.get_device_module(self.device).stream(self.forward_stream):
        with (
            torch.inference_mode(),
            autotune(True, cache=str(autotune_cache)),
            autotune_dummy_run_mode(), # 新增:跳过 LM head
        ):
            self._dummy_run(batch_size=self.req_to_token_pool.size)
    torch.cuda.current_stream().wait_stream(self.forward_stream)
    logger.info("FlashInfer autotune completed.")

评论区精华

是否可使用 flashinfer 自带的全局标志替代自建 context manager? 设计

Fridge003 询问能否从 flashinfer.autotune 中获取全局标志,从而避免定义自定义 context manager。kpham-sgl 回复称 flashinfer 有一个相关 PR(#3396)但尚未合并,目前没有替代方案。

结论:由于 flashinfer 尚未提供对应接口,保留当前自定义 context manager 实现。 · 已解决

风险与影响

变更范围小(仅两个文件),且只在 autotune 路径下生效,不影响正常推理路径。风险较低。但需注意 LogitsProcessor.forward 的早期返回必须放在所有分支之前,当前实现已满足,但未来若增加新的早期分支需要留意顺序。

仅影响 FlashInfer 自动调优的 dummy run 阶段,解决特定模型(GLM51)在 DP attention + 紧内存预算下的 OOM 问题。无其他影响。

核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论