Prhub

#41499 [Performance] Make safetensors checkpoint prefetch settings configurable

原始 PR 作者 mmangkad 合并时间 2026-05-10 23:55 文件变更 4 提交数 7 评论 7 代码增减 +87 / -15

执行摘要

使 safetensors 检查点预取参数可配置

之前 safetensors 预取使用硬编码值(8 线程、16MB 块),无法适应不同存储系统(如 NFS、本地 SSD)的调优需求。该 PR 使其可配置,同时保持默认值不变。

值得精读,尤其是参数化设计模式和并发原语替换的讨论。建议后续补充单元测试和性能基准数据。

讨论亮点
  • Semaphore 冗余:gemini-code-assist 指出 asyncio.SemaphoreThreadPoolExecutor 已限制并发而冗余,作者在提交中已将其移除。
  • 行为改变疑虑:arpera 担心将 asyncio.to_thread 改为 ThreadPoolExecutor 可能改变预取与模型加载的并行语义。mmangkad 回应解释预取仍在后台线程运行,_prefetch_all_checkpoints 立即返回,且新实现更显式可控。
  • 性能数据要求:arpera 要求上传实测性能数据(硬件、存储类型及对照表),以证明无回归并提供调优依据。目前作者尚未提供公开数据。

实现拆解

  1. 配置层 (vllm/config/load.py):新增模块级常量 DEFAULT_SAFETENSORS_PREFETCH_NUM_THREADS = 8DEFAULT_SAFETENSORS_PREFETCH_BLOCK_SIZE = 16 * 1024 * 1024,并在 LoadConfig 中添加两个 Pydantic 字段 safetensors_prefetch_num_threads(ge=1)和 safetensors_prefetch_block_size(ge=1)。
  2. CLI 接口 (vllm/engine/arg_utils.py):在 EngineArgs 中增加同名属性,并在 add_cli_args 中注册 --safetensors-prefetch-num-threads--safetensors-prefetch-block-size 参数,同时将 safetensors_prefetch_block_size 标记为 human_readable_int 以支持可读性格式。create_load_config 将值传递给 LoadConfig 构造函数。
  3. 核心预取逻辑 (vllm/model_executor/model_loader/weight_utils.py):
    • _prefetch_checkpoint 新增 block_size 参数(默认为常量)并增加有效性校验。
    • _prefetch_all_checkpoints 新增 num_prefetch_threadsblock_size 参数,将内部 asyncio.Semaphore 机制替换为显式的 ThreadPoolExecutor(max_workers=num_prefetch_threads),并通过 loop.run_in_executor 提交预取任务,从而更精确控制并发度。
    • 移除了冗余的 asyncio.Semaphore,因为 ThreadPoolExecutor 已限制并发。
  4. 向下传递 (vllm/model_executor/model_loader/default_loader.py):在 _get_weights_iterator 中,当调用 safetensors_weights_iterator 时传入 safetensors_prefetch_num_threadssafetensors_prefetch_block_size
    测试配套:本次变更未新增测试文件,但作者声称在 B300 和 GB300 上手动验证了默认行为保持兼容,且自定义参数能被正确使用。
文件 模块 状态 重要度
vllm/model_executor/model_loader/weight_utils.py 模型加载 modified 8.37
vllm/engine/arg_utils.py 引擎配置 modified 6.6
vllm/config/load.py 配置模型 modified 5.8
vllm/model_executor/model_loader/default_loader.py 模型加载 modified 5.4

关键符号

_prefetch_checkpoint _prefetch_all_checkpoints prefetch_one add_cli_args create_load_config _get_weights_iterator

关键源码片段

vllm/model_executor/model_loader/weight_utils.py core-logic

核心预取逻辑的变更,包括函数签名扩展、并发模型重构(ThreadPoolExecutor)和参数校验。

# vllm/model_executor/model_loader/weight_utils.pyDEFAULT_SAFETENSORS_PREFETCH_BLOCK_SIZE = 16 * 1024 * 1024 # 16MiB, 默认块大小
DEFAULT_SAFETENSORS_PREFETCH_NUM_THREADS = 8 # 默认线程数def _prefetch_checkpoint(
    file_path: str,
    block_size: int = DEFAULT_SAFETENSORS_PREFETCH_BLOCK_SIZE,
) -> None:
    """将检查点文件预读到操作系统页面缓存中。    以指定块大小读取文件,使内核在 worker 加载前缓存页面。
    """
    if block_size < 1:
        raise ValueError("safetensors prefetch block size must be >= 1")
    with open(file_path, "rb") as f:
        while f.read(block_size):
            passdef _prefetch_all_checkpoints(
    sorted_files: list[str],
    num_prefetch_threads: int = DEFAULT_SAFETENSORS_PREFETCH_NUM_THREADS,
    block_size: int = DEFAULT_SAFETENSORS_PREFETCH_BLOCK_SIZE,
) -> None:
    """在后台线程中启动检查点文件的预取。"""
    if num_prefetch_threads < 1:
        raise ValueError("safetensors prefetch num threads must be >= 1")
    if block_size < 1:
        raise ValueError("safetensors prefetch block size must be >= 1")
​
    # 分布式场景下每个 rank 预取其负责的一部分文件
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        world_size = 1
    paths_to_prefetch = sorted_files[rank::world_size]
    total_for_rank = len(paths_to_prefetch)
​
    async def _prefetch_all() -> None:
        loop = asyncio.get_running_loop()
        completed = 0
        next_log_pct = 10
​
        async def prefetch_one(
            path: str,
            executor: concurrent.futures.ThreadPoolExecutor,
        ) -> None:
            nonlocal completed, next_log_pct
            try:
                # 使用 ThreadPoolExecutor 控制并发,替代之前的 asyncio.Semaphore
                await loop.run_in_executor(
                    executor, _prefetch_checkpoint, path, block_size
                )
                completed += 1
                if total_for_rank > 0 and next_log_pct <= 100:
                    pct = 100 * completed / total_for_rank
                    if pct >= next_log_pct:
                        logger.info(
                            "Prefetching checkpoint files: %d%% (%d/%d)",
                            int(pct), completed, total_for_rank)
                        next_log_pct += 10
            except Exception:
                logger.warning(
                    "Failed to prefetch checkpoint file %r.", path, exc_info=True)
​
        # 显式创建线程池,size 由参数控制,更精确
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=num_prefetch_threads
        ) as executor:
            await asyncio.gather(
                *(prefetch_one(p, executor) for p in paths_to_prefetch)
            )
​
    def _run_prefetch() -> None:
        start = time.perf_counter()
        asyncio.run(_prefetch_all())
        elapsed = time.perf_counter() - start
        logger.info("Prefetching %d checkpoint files took %.2f seconds.",
                    total_for_rank, elapsed)
​
    # 在守护线程中启动,不阻塞主流程
    thread = threading.Thread(target=_run_prefetch, daemon=True)
    thread.start()

评论区精华

asyncio.Semaphore 冗余 设计

gemini-code-assist 指出 asyncio.Semaphore 因 ThreadPoolExecutor 已限制并发而冗余。

结论:作者已移除 Semaphore,采用纯 ThreadPoolExecutor 控制并发。 · 已解决

ThreadPoolExecutor 行为变化疑虑 设计

arpera 担心从 asyncio.to_thread 改为 ThreadPoolExecutor 可能改变预取与模型加载的并行语义。

结论:mmangkad 解释预取仍在后台守护线程运行,新实现更显式可控,且进行了回归测试。 · 已解决

请求性能数据 性能

arpera 要求上传硬件、存储类型和性能数据以证明无回归并提供调优参考。

结论:作者未提供公开数据,PR 已合入但未满足该请求。 · unresolved

风险与影响

  1. 核心路径变更风险weight_utils.py 是模型加载核心路径,替换并发原语可能引入潜在调度变化,尽管作者声称行为等价,但缺少单元测试覆盖。
  2. 缺少测试覆盖:没有新增自动化测试验证新参数是否正确传递及默认行为兼容。
  3. 整数溢出风险block_sizenum_prefetch_threads 添加了 ge=1 校验,但未限制上限,极端值可能导致内存压力或性能问题。

对用户:提供灵活调优能力,尤其适合使用 NFS 等共享存储的场景。对系统:默认行为完全向后兼容,用户无需修改现有配置。对团队:新增两个 CLI 参数需要维护文档和向后兼容性。

核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论