Prhub

#26937 Add per-rank staggered weight loading for improved TP I/O concurrency

原始 PR 作者 power-more 合并时间 2026-06-03 11:25 文件变更 4 提交数 4 评论 5 代码增减 +34 / -10

执行摘要

TP 权重加载排序与交错 I/O 优化

PR body 指出两个问题:

1) 内部迭代器里的 sorted() 调用覆盖了 SGLANG_SORT_WEIGHT_FILES 的设置,且仅有 safetensors 和 multi-thread 两条路径有排序,其他路径不一致;
2) TP>1 时所有 rank 以完全相同顺序读取文件,无法利用并发 I/O 优势。通过统一排序控制和使用交错排序来提升多 rank 的 I/O 并发度。

值得仔细阅读 loader.py 中交错逻辑的实现,并确认默认行为变更已广而告之。建议在 test/registered 中添加一个加载相关测试,覆盖 k=-1, 0, 1, 2 等场景,确保回归捕获。

讨论亮点

无实质性 review 讨论。ShangmingCai 审核并批准,仅提交了两次 lint 修复提交。

实现拆解

  1. 环境变量升级(environ.py:将 SGLANG_SORT_WEIGHT_FILESEnvBool(False) 改为 EnvInt(0),新增三种语义:-1(不排序)、0(仅排序,默认)、k>0(排序并按因子 k 交错)。
  2. 集中排序控制(loader.py:在 _prepare_weights() 中移除旧的 if envs.SGLANG_SORT_WEIGHT_FILES.get(): hf_weights_files.sort(),替换为根据 k 值和 TP 信息的集中排序+交错逻辑。当 k>0 时,将文件列表按 (tp_size * k) 分组,每个 rank 在其组内循环偏移 (tp_rank * k),达成交错。
  3. 移除迭代器内硬编码排序(weight_utils.py:从 safetensors_weights_iteratorbuffered_multi_thread_safetensors_weights_iterator 中删除 sorted_files = sorted(hf_weights_files),改用传入的 hf_weights_files(此时已由 loader 排序/交错)。仅保留 _prefetch_all_checkpoints 内部对 sorted(hf_weights_files) 的调用,确保预取顺序一致性。
  4. 对齐文本编码器加载(text_encoder_loader.py:将 text_encoder_loader_prepare_weights 的条件从 if envs.SGLANG_SORT_WEIGHT_FILES.get() 改为 if envs.SGLANG_SORT_WEIGHT_FILES.get() >= 0,在支持排序的同时避免不必要的交错(因为文本编码器无 TP 拆分)。
文件 模块 状态 重要度
python/sglang/srt/model_loader/loader.py 模型加载 modified 6.56
python/sglang/srt/model_loader/weight_utils.py 模型加载 modified 6.08
python/sglang/srt/environ.py 配置 modified 5.31
python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py 文本编码器 modified 4.67

关键符号

_prepare_weights safetensors_weights_iterator buffered_multi_thread_safetensors_weights_iterator

关键源码片段

python/sglang/srt/model_loader/loader.py data-contract

核心修改:将排序和交错逻辑集中到 `_prepare_weights`,取代旧的单行 sort 调用;新增基于 TP 大小的分组交错算法。

def _prepare_weights(self, source, revision, fall_back_to_pt):
    # ... 过滤、去重等逻辑 ...
    # Sort and optionally stagger weight files (see SGLANG_SORT_WEIGHT_FILES).
    # k=-1: no sort; k=0: sort only; k>0: sort + stagger by (tp_rank * k).
    k = envs.SGLANG_SORT_WEIGHT_FILES.get()
    if k >= 0:
        hf_weights_files.sort()
        if k > 0:
            tp_size = get_tensor_model_parallel_world_size()
            if tp_size > 1:
                tp_rank = get_tensor_model_parallel_rank()
                group_size = tp_size * k
                staggered: List[str] = []
                for i in range(0, len(hf_weights_files), group_size):
                    group = hf_weights_files[i : i + group_size]
                    n = len(group)
                    # 每个 rank 在组内循环偏移 (tp_rank * k) 个位置
                    staggered.extend(group[(j + tp_rank * k) % n] for j in range(n))
                hf_weights_files = staggered
    return hf_folder, hf_weights_files, use_safetensors
python/sglang/srt/model_loader/weight_utils.py data-contract

删除迭代器内硬编码的 sorted(),改为直接使用传入的文件列表;预取时仍保留 sorted() 以保证一致。

def safetensors_weights_iterator(
    hf_weights_files: List[str],
    disable_mmap: bool = False,
    prefetch: bool = False,
    prefetch_num_threads: int = 4,
    drop_cache_after_load: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )
​
    # 预取时仍然使用 sorted() 以确保页面缓存被按序访问
    if prefetch and not disable_mmap:
        _prefetch_all_checkpoints(
            sorted(hf_weights_files), num_threads=prefetch_num_threads
        )
​
    for st_file in tqdm(
        hf_weights_files, # 直接使用已排序 / 交错的列表
        desc="Loading safetensors checkpoint shards",
        disable=not enable_tqdm,
        bar_format=BAR_FORMAT,
        position=tqdm._get_free_pos(),
    ):
        # ... 加载逻辑 ...
python/sglang/srt/environ.py core-logic

定义新的整型环境变量,注释说明各取值语义(-1/0/k>0),是功能配置的入口。

class Envs:
    # fmt: off
    # Model & File Download
    SGLANG_USE_MODELSCOPE = EnvBool(False)
    # Controls weight-file ordering for load-time I/O optimization.
    # -1 : no sorting, no staggering; preserves original file order.
    # 0 : sort files only; maximizes ordering but may reduce cross-rank I/O concurrency.
    # k>0: sort files and stagger per-rank order with factor k.
    # Files are processed in groups of (tp_size * k), and rank r starts each
    # group at offset (r * k), improving multi-rank I/O concurrency while
    # keeping access relatively ordered.
    SGLANG_SORT_WEIGHT_FILES = EnvInt(0)

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  1. 回归风险:默认行为从 EnvBool(False)(不排序)变为 EnvInt(0)(排序),可能影响依赖原始文件顺序的旧部署。PR 中说明这是有意的,但需要在发布说明中强调。
  2. 兼容性:环境变量类型从布尔型变为整型,若存在旧配置直接写 SGLANG_SORT_WEIGHT_FILES=true 会导致解析失败。但 EnvInt 的 parse 会处理常见表达式(如 True 被转换为 1),风险可控。
  3. 性能回归:在缓存充足场景下,仅排序模式的加载时间比不排序差(155s vs 56s),但交错模式有助于缓解此问题。

用户角度:TP 模型部署的启动加载时间可改善,特别是共享文件系统场景下;用户可通过设置 SGLANG_SORT_WEIGHT_FILES=1 开启交错获得加速。系统角度:减少了权重加载阶段的文件顺序耦合,使多 rank 的 I/O 分布更均匀,有利于扩展到大 TP 规模。团队角度:统一的排序控制点降低了后续增加新加载路径时的认知负担,并在 text_encoder_loader 中展示了可复用的模式。

默认行为变更 缺少测试覆盖 缓存充足场景性能回退

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论