Prhub

#26861 [loader] Reduce transient allocations in NVFP4 MoE setup

原始 PR 作者 yinghai 合并时间 2026-06-04 12:13 文件变更 2 提交数 2 评论 4 代码增减 +68 / -50

执行摘要

预分配 NVFP4 MoE 权重张量避免内存碎片

PR提交信息指出:'Preallocate shuffled weight and scale tensors for TRTLLM FP4 MoE setup, and skip temporary blockscale swizzle placeholders when TRTLLM replaces them after weight loading. This avoids memory fragmentation and save a few GBs of HBM depending on the models.' 合并者merrymercy在提交消息中明确说明避免GPU内存碎片。

此PR值得精读,尤其是对内存在GPU上管理有优化兴趣的工程师。设计模式:避免临时分配列表再堆叠,而是预分配和重用缓冲区;条件跳过无关工作以减少内存峰值。

讨论亮点

合并者merrymercy直接批准并触发了rerun-test。AI 审核机器人给出了正面总结。无争议或未解决问题。

实现拆解

  1. 预分配输出张量utils.py):将原先的list + stack模式改为先创建 torch.empty_like 的连续大张量,然后在循环中直接赋值子 slice,避免每个专家一次临时分配。
  2. 引入可重用scratch buffer:新增 _alloc_scale_buffers 辅助函数,一次性分配 scale 输出张量和 permuted 输入的暂存缓冲区(scratch)。循环中使用 torch.index_select 将 permuted 数据写入 scratch,然后调用 nvfp4_block_scale_interleave 直接输出到目标 slice,避免 .contiguous() 调用。
  3. 条件跳过 blockscale swizzlemodelopt_quant.py):在 ModelOptFP4Config.create_weights 方法中,当 self.enable_flashinfer_trtllm_moe 为 True 时,将 w13_blockscale_swizzledw2_blockscale_swizzled 设为 None,而非直接创建 swizzle_blockscale 参数,因为 TRTLLM 会在 process_weights_after_loading 中替换该张量。
  4. 格式化:第二个提交运行了 black 格式化。
文件 模块 状态 重要度
python/sglang/srt/layers/quantization/utils.py 量化工具 modified 7.18
python/sglang/srt/layers/quantization/modelopt_quant.py 量化配置 modified 5.98

关键符号

_alloc_scale_buffers prepare_static_weights_for_trtllm_fp4_moe ModelOptFP4Config.create_weights

关键源码片段

python/sglang/srt/layers/quantization/utils.py core-logic

核心优化逻辑所在:重写 prepare_static_weights_for_trtllm_fp4_moe 函数,引入预分配和 scratch buffer,减少临时 GPU 分配。

def _alloc_scale_buffers(scales):
    # 获取每个 expert 的 scale 输入形状和元素数
    per_expert_shape = scales[0].view(torch.uint8).shape
    per_expert_numel = scales[0].numel()
    # 预分配整个输出张量 (num_experts, per_expert_numel) 和一个可复用的 scratch 缓冲区
    output = scales.new_empty((num_experts, per_expert_numel), dtype=torch.uint8)
    scratch = torch.empty(per_expert_shape, dtype=torch.uint8, device=scales.device)
    return output, scratch# 预分配 weight 和 scale 输出张量
# 原代码使用 list + torch.stack,每个 expert 都会产生一次临时分配
gemm1_weights_fp4_shuffled = torch.empty_like(gemm1_weights_fp4.view(torch.uint8))
gemm2_weights_fp4_shuffled = torch.empty_like(gemm2_weights_fp4.view(torch.uint8))
gemm1_scales_fp4_shuffled, g1s_scratch = _alloc_scale_buffers(gemm1_scales_linear_fp4)
gemm2_scales_fp4_shuffled, g2s_scratch = _alloc_scale_buffers(gemm2_scales_linear_fp4)for i in range(num_experts):
    # ... 获取 permute_indices 和 permute_sf_indices 的代码保持不变 ...
    # 直接写入预分配张量的第 i 个 slice,避免 append + contiguous
    gemm1_weights_fp4_shuffled[i] = gemm1_weights_fp4[i].view(torch.uint8)[permute_indices.to(...)]
​
    # 使用 index_select 将 permuted 数据写入 scratch,然后 interleave 输出到目标 slice
    torch.index_select(
        gemm1_scales_linear_fp4[i].view(torch.uint8),
        0,
        permute_sf_indices.to(...),
        out=g1s_scratch,
    )
    gemm1_scales_fp4_shuffled[i] = nvfp4_block_scale_interleave(g1s_scratch)
​
    # 对 w2 同理
    gemm2_weights_fp4_shuffled[i] = gemm2_weights_fp4[i].view(torch.uint8)[permute_indices.to(...)]
    torch.index_select(...)
    gemm2_scales_fp4_shuffled[i] = nvfp4_block_scale_interleave(g2s_scratch)del g1s_scratch, g2s_scratch

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

主要风险在于逻辑等价性:原先list+stack后contiguous(),新方式直接赋值slice可能改变内存连续性假设,但预分配的大张量本身是连续的,slice也是连续部分,且TRTLLM内核期望连续的uint8张量,风险较低。当 enable_flashinfer_trtllm_moe 为 False 时,行为保持不变;只有 True 时才跳过 swizzle,不会影响非 TRTLLM 路径。潜在性能风险:如果scratch buffer在不同CUDA流上导致同步问题,但所有操作都在同一流上且循环串行,应该没问题。缺少直接单元测试覆盖,但已有集成测试覆盖FP4模型加载。

影响范围限于使用TRTLLM FP4 MoE量化的大模型(如DeepSeek V3)。减少内存碎片可能使更大规模的模型部署或更大的batch成为可能。无用户API变化,属于后台优化,对开发者和运维透明。

缺少单元测试覆盖 TRTLLM 特定优化

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论