Prhub

#25523 [Diffusion] Default NVFP4 backend to FlashInfer TRTLLM

原始 PR 作者 BBuf 合并时间 2026-05-25 18:14 文件变更 11 提交数 2 评论 3 代码增减 +141 / -57

执行摘要

默认 NVFP4 后端切换为 FlashInfer TRTLLM,支持 swizzled scale 布局

PR 跟随 #25857 中 Wan2.2 NVFP4 checkpoint 的落地方案,进一步将 diffusion ModelOpt NVFP4 的默认后端从 cudnn 改为 flashinfer_trtllm,因为 trtllm 在 Blackwell GPU 上对常见形状(如 FLUX.1-dev、Wan2.2)具有显著性能优势(见 PR 性能表格:FLUX.1-dev 生成时间降低约 34%)。此外,官方 FLUX.2 NVFP4 导出的 weight scale 是 FlashInfer/CUTLASS-swizzled 布局,而 SGLang 转换的 repos 保持线性布局,需要新增布局追踪和转换逻辑以正确加载这类 checkpoint。

该 PR 值得精读,特别是 scale layout 的处理设计和默认后端的切换策略。建议重点关注 _swizzled_nvfp4_scales_to_linear 的实现以及 review 中暴露的潜在风险,未来可能需统一后端间的 scale 处理路径。

讨论亮点

review 评论:scale un-swizzling 逻辑应移至后端通用位置

gemini-code-assist[bot] 指出,当前 swizzled scale 的转换在 process_weights_after_loading 中只在进入 trtllm 后端分支前执行,但转换后的 scales 仅用于 trtllm 路径;对于其他后端(cudnncutlasssgl_kernel 回退),layer.weight_scale 仍保留 swizzled 布局,可能导致错误。建议将转换逻辑提前到所有后端共享部分,并更新 layer.weight_scale 参数。该评论未得到显式回复或关闭,PR 仍然合并。

实现拆解

1. 变更默认后端选择逻辑

修改 python/sglang/multimodal_gen/runtime/platforms/cuda.py 中的 get_modelopt_flashinfer_fp4_backend,将默认后端从 cudnn (Blackwell) 或 auto 统一改为 trtllm。同时重写 get_modelopt_fp4_gemm_op,移除旧版的有条件优先逻辑,直接优先尝试 flashinfer.mm_fp4,失败则回退到 sgl_kernel.cutlass_scaled_fp4_mm

2. 新增 swizzled scale layout 支持

python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py 中新增 _swizzled_nvfp4_scales_to_linear 函数,将 FlashInfer/CUTLASS 的 swizzled layout 转换回 row-major 线性布局。在 ModelOptFp4Config 中添加 checkpoint_weight_scale_layout 字段(默认 linear),并在 process_weights_after_loading 中根据该字段决定是否对 weight_scale 进行转换。

3. 配置合并与推断逻辑

修改 transformer_load_utils.py 中的 _merge_modelopt_fp4_configs,当两个配置的 checkpoint_weight_scale_layout 不同时优先保留非 linear 的值。修改 quantization_utils.py 中的 _build_nvfp4_config_from_safetensors,在检测到 packed QKV 时自动推断 layout 为 swizzled

4. 测试与文档

新增测试 test_flux2_swizzled_scale_checkpoint_flashinfer_trtllm_matches_cudnn,验证从 swizzled scale checkpoint 加载后,flashinfer_trtllm 输出与 cudnn 后端一致。更新 docs_new/docs/sglang-diffusion/quantization.mdxdocs/diffusion/quantization.md 以反映默认后端变更;更新 envs.py 和对应文档注释。

文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py 量化层 modified 7.41
python/sglang/multimodal_gen/runtime/platforms/cuda.py 平台层 modified 6.71
python/sglang/jit_kernel/tests/diffusion/test_diffusion_nvfp4_scaled_mm.py NVFP4 测试 modified 6.05
python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py 模型加载器 modified 5.9
python/sglang/multimodal_gen/runtime/utils/quantization_utils.py 量化工具 modified 5.85
python/sglang/multimodal_gen/envs.py 环境变量 modified 4.3
docs_new/docs/sglang-diffusion/quantization.mdx 文档(新) modified 3.67

关键符号

_swizzled_nvfp4_scales_to_linear get_modelopt_flashinfer_fp4_backend get_modelopt_fp4_gemm_op ModelOptFp4LinearMethod.process_weights_after_loading _merge_modelopt_fp4_configs _build_nvfp4_config_from_safetensors test_flux2_swizzled_scale_checkpoint_flashinfer_trtllm_matches_cudnn

关键源码片段

python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py data-contract

核心变更:新增 swizzled scale 转换函数、配置字段和 weights 加载逻辑

# python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py# 新增函数:将 FlashInfer/CUTLASS 的 swizzled FP4 scales 转换回 row-major 线性布局
# 用于加载官方 FLUX.2 等以 swizzled 格式存储的 checkpoint
def _swizzled_nvfp4_scales_to_linear(scales: torch.Tensor) -> torch.Tensor:
    scale_ndim = scales.ndim
    if scale_ndim == 2:
        scales = scales.unsqueeze(0) # 增加 batch 维度使处理统一
    assert scales.ndim == 3
    B, M, K = scales.shape
    M_padded = round_up(M, 128) # 对齐到 128 的倍数
    K_padded = round_up(K, 4) # 对齐到 4 的倍数
    if M != M_padded or K != K_padded:
        padded = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype, device=scales.device)
        padded[:B, :M, :K] = scales
        scales = padded
    # 反 swizzle:reshape + permute 还原线性布局
    linear = scales.reshape(B, M_padded // 128, K_padded // 4, 32, 4, 4)
    linear = linear.permute(0, 1, 4, 3, 2, 5).contiguous()
    linear = linear.reshape(B, M_padded, K_padded)[:, :M, :K]
    return linear.squeeze(0) if scale_ndim == 2 else linear# 在 process_weights_after_loading 中,先根据配置转换 scale,再根据后端处理
scales = layer.weight_scale
if getattr(self.quant_config, 'checkpoint_weight_scale_layout', 'linear') == 'swizzled':
    scales = _swizzled_nvfp4_scales_to_linear(scales)
_, flashinfer_backend = _get_fp4_gemm_op()
if flashinfer_backend == 'trtllm':
    # trtllm 后端需要额外 padding 和 shuffle
    weight, _ = pad_nvfp4_weight(w_swapped, n_alignment=128, k_alignment=0)
    if scales.shape[0] != weight.shape[0]:
        pad_n = weight.shape[0] - scales.shape[0]
        scales = torch.nn.functional.pad(scales, (0, 0, 0, pad_n))
    # ... 应用 shuffle 并写入 layer.weight_scale_interleaved
python/sglang/multimodal_gen/runtime/platforms/cuda.py dependency-wiring

后端选择逻辑核心:更改默认后端为 trtllm 并简化 GEMM op 获取流程

# python/sglang/multimodal_gen/runtime/platforms/cuda.py@classmethod
@lru_cache(maxsize=1)
def get_modelopt_flashinfer_fp4_backend(cls) -> str:
    backend = envs.SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND
    default_backend = 'trtllm' # PR 核心变更:默认后端统一为 trtllm
    if backend is None:
        return default_backend
    backend = backend.lower()
    # 支持多个别名
    backend = {
        'flashinfer_cudnn': 'cudnn',
        'flashinfer_cutlass': 'cutlass',
        'flashinfer_trtllm': 'trtllm',
        'trtllm': 'trtllm',
        'cudnn': 'cudnn',
        'auto': 'auto',
    }.get(backend, backend)
    if backend not in {'auto', 'cudnn', 'cutlass', 'trtllm'}:
        logger.warning('Unsupported backend %r, falling back to %r', backend, default_backend)
        return default_backend
    return backend@classmethod
@lru_cache(maxsize=1)
def get_modelopt_fp4_gemm_op(cls) -> tuple[Callable | None, str | None]:
    requested_backend = envs.SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND
    try:
        from flashinfer import mm_fp4 as flashinfer_mm_fp4
        return flashinfer_mm_fp4, cls.get_modelopt_flashinfer_fp4_backend()
    except ImportError:
        logger.warning(
            'flashinfer.mm_fp4 unavailable, falling back to cutlass (requested: %r)',
            requested_backend or 'flashinfer_trtllm (default)',
        )
    try:
        from sgl_kernel import cutlass_scaled_fp4_mm as cutlass_fp4_gemm
        return cutlass_fp4_gemm, None
    except ImportError:
        return None, None

评论区精华

swizzled scale 转换逻辑仅对 trtllm 后端生效,其他后端可能出错 正确性

gemini-code-assist[bot] 在 modelopt_quant.py 的 diff 中指出,“The logic to un-swizzle weight scales from the checkpoint is currently implemented only within the flashinfer_trtllm backend block. However, checkpoints with swizzled scales will also be broken when falling back to other backends because those backends expect linear scales. This un-swizzling logic should be moved before the backend-specific blocks to ensure correct behavior across all backends.”

结论:PR 合并时该问题未被解决;当前实现仅转换局部 scales 变量,仅 trtllm 路径使用,其他后端仍使用原始 layer.weight_scale。 · unresolved

风险与影响

  1. 后向兼容性风险:已有用户依赖 cudnn 后端时未设置环境变量,切换默认值后可能面临不同的数值行为或性能变化(但性能应更好)。若用户希望保留原后端需显式设置 SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND=cudnn
  2. swizzled scale 覆盖不全:根据 review 评论,非 trtllm 后端(如 cudnncutlass)遇到 swizzled layout 的检查点时,process_weights_after_loading 中转换后的 scales 未写回 layer.weight_scale,可能导致这些后端加载失败或产生错误结果。需要确认其他后端是否也期望 linear 布局。
  3. 依赖 flashinfer:新默认路径依赖 flashinfer.mm_fp4,若用户环境未安装 flashinfer 或版本不兼容,会自动回退到 cutlass。回退路径的数值一致性已通过测试验证,但性能可能下降。

用户影响:使用 diffusion NVFP4 的用户(FLUX、Wan2.2 等)将默认获得 flashinfer_trtllm 后端的性能提升。需关注兼容 checkpoints 的 layout 标记,若加载官方 FLUX.2 导出可自动识别 swizzled layout。
系统影响:无直接系统级变更。
团队影响:维护成本增加,需跟踪不同后端对 scale layout 的处理一致性。

后向兼容性风险 swizzled scale 未覆盖非 trtllm 后端 依赖 flashinfer 导致回退可能

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论