Prhub

#23625 Flux2 nvfp4 quantization correctness on Blackwell (B200)

原始 PR 作者 Johnsonms 合并时间 2026-05-02 09:57 文件变更 8 提交数 17 评论 14 代码增减 +128 / -42

执行摘要

修复 FLUX.2 NVFP4 在 B200 上的量化正确性

在 B200 硬件上,FLUX.2-dev-NVFP4 量化模型输出全白图像(均值≈252,标准差≈2),而 BF16 版本输出正常。经调查,上游 main 分支存在三个正确性 bug,共同导致 NVFP4 路径失效。详见 PR #23625 body。

值得精读。关注 process_weights_after_loading 的条件化设计、per-GEMM z-score 调试方法。建议未来建立可配置命名映射机制。

讨论亮点

在 Review 中,BBuf 指出 swap_weight_nibbles 在 from_config 中已默认 True,作者移除显式设置。BBuf 询问是否仅改 modelopt_quant.py 即可修复,作者确认需三个修复同时应用。OrangeRedeng 质疑 mlp.py 和 wanvideo.py 的 prefix 更名影响 Wan 模型,该兼容性问题未彻底解决。

实现拆解

  1. 修复 Input Scale 缺失:在 flux_2_nvfp4.py 中为 16 个 FP4 量化 txt_mlp 层补全 input_scale 参数。
  2. 条件化 TMA Scale 排列:在 modelopt_quant.py 的 process_weights_after_loading 中根据后端类型决定是否进行 blockwise interleave。
  3. 修复加载器回退:在 quantization_utils.py 修正排除模块映射,在 mlp.py 和 wanvideo.py 中统一前缀命名。
  4. 增强检查点恢复:在 utils.py 新增 _try_redownload_missing_shards 函数,自动修复不完整检查点。
  5. 兼容性修复与 CI 恢复:在 component_loader.py 增加 RobertaProcessing 回退,恢复 B200 CI。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/loader/utils.py 加载工具 modified 7.42
python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py 量化层 modified 6.4
python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py 组件加载 modified 6.56
python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py 模型定义 modified 6.31
python/sglang/multimodal_gen/runtime/utils/quantization_utils.py 量化工具 modified 5.11
python/sglang/multimodal_gen/runtime/layers/mlp.py MLP 层 modified 5.11

关键符号

_try_redownload_missing_shards _list_safetensors_files process_weights_after_loading create_weights _build_nvfp4_config_from_safetensors_files load_customized

关键源码片段

python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py data-contract

修复 TMA scale 排列条件化,确保 cuDNN 后端获得正确的 row-major scales

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # ... 前面的 alpha, input_scale_inv 计算
    scales = layer.weight_scale
    scale_ndim = scales.ndim
    if scale_ndim == 2:
        scales = scales.unsqueeze(0)
    assert scales.ndim == 3
    B, M, K = scales.shape
    M_padded = round_up(M, 128)
    K_padded = round_up(K, 4)
    padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
    padded_scales[:B, :M, :K] = scales
​
    # 关键变更:仅在 CUTLASS 路径下应用 TMA 排列
    _, flashinfer_backend = _get_fp4_gemm_op()
    if flashinfer_backend is None:
        # CUTLASS (sgl_kernel) 路径:blockwise interleave 适应 TMA 布局
        padded_scales = padded_scales.reshape(
            B, M_padded // 128, 4, 32, K_padded // 4, 4
        )
        padded_scales = padded_scales.permute(0, 1, 4, 3, 2, 5)
​
    padded_scales = padded_scales.contiguous().cuda()
    padded_scales = (
        padded_scales.reshape(M_padded, K_padded)
        if scale_ndim == 2
        else padded_scales.reshape(B, M_padded, K_padded)
    )
    copy_or_rebind_param(layer, 'weight_scale_interleaved', padded_scales)

评论区精华

swap_weight_nibbles 默认值冗余 设计

BBuf 指出显式设置 swap_weight_nibbles=True 多余,因为 from_config() 已默认 True。

结论:作者移除该显式设置,E2E 行为无变化。 · 已解决

修复范围是否仅需 modelopt_quant.py 正确性

BBuf 询问是否只改 modelopt_quant.py 就能修复全部问题。

结论:作者验证后确认需要三个修复同时应用。 · 已解决

prefix 命名兼容性影响 Wan 模型 question

OrangeRedeng 质疑 mlp.py 和 wanvideo.py 的 prefix 更名是否必要,是否破坏 Wan 等其他模型的量化加载(CI 显示 Wan 测试失败)。

结论:作者表示更名是为了匹配排除模块命名,当前 PR 已合并,后续需多架构适配。 · 待处理

风险与影响

1) prefix 修改可能影响其他模型(如 Wan),已由 CI 失败证实。
2) 自动修复依赖 hf_hub_download 网络,离线环境失败。
3) 仅 B200 验证,Hopper/Ada 未回归。
4) 无新增单元测试覆盖自动修复路径。

正面:FLUX.2-dev-NVFP4 在 B200 上恢复正常图像质量,检查点自动修复提高鲁棒性。负面:Wan 等模型可能因 prefix 变更而量化加载失败,需后续适配。

核心路径变更(量化路径) 跨模型兼容性(Wan prefix 影响) 无新增单测覆盖自动修复 仅验证 B200 硬件 离线环境自动修复依赖 HuggingFace Hub 网络

关联 Issue

#3 Add install with pip

完整报告

参与讨论