Prhub

#23745 Use Cute-DSL NVFP4 quantization kernels

原始 PR 作者 b8zhong 合并时间 2026-05-11 15:40 文件变更 8 提交数 1 评论 9 代码增减 +212 / -141

执行摘要

SM100 默认使用 Cute-DSL NVFP4 量化,性能提升

在Blackwell(SM100)上,FlashInfer的Cute-DSL FP4量化核经过性能优化后全面超越原始CUDA核(参见flashinfer#2904),因此需要将其集成并设为默认后端,以获得最高性能。

本PR值得关注其通过注册custom_op实现CUDA graph兼容的技巧,以及在不同后端间自动选择的设计模式。对于要修改量化后端的开发者,是很好的参考。

讨论亮点

Fridge003建议将fp4_quantize从modelopt_quant.py移到fp4_utils.py,作者立即执行。对于是否删除未使用的jit kernel,作者表示可作为后续改进。关于piecewise CUDA graph兼容性,作者展示了启动服务器的成功日志,证明register_custom_op方案可以正常工作。

实现拆解

  1. 在fp4_utils.py中引入FlashInfer的fp4_quantize,根据SM100判断自动选择cute-dsl或cuda backend。
  2. 使用register_custom_op_from_extern注册量化算子,并提供fake实现以支持CUDA graph捕获。
  3. 移除modelopt_quant.py中旧的fp4_quantize导入和fallback代码,改为从fp4_utils导入。
  4. 更新standard.py、flashinfer.py、compressed_tensors等文件中的导入路径,统一引用fp4_utils的fp4_quantize。
  5. 重写bench_fp4_quant.py基准脚本,直接对比sglang和flashinfer的量化性能,并增加绘图和CSV输出。
文件 模块 状态 重要度
benchmark/kernels/quantization/bench_fp4_quant.py 性能测试 modified 8.48
python/sglang/srt/layers/quantization/fp4_utils.py 量化工具 modified 8.02
python/sglang/srt/layers/quantization/modelopt_quant.py 量化配置 modified 6.06

关键符号

_flashinfer_fp4_quantize_impl _flashinfer_fp4_quantize_fake _round_up benchmark _bench main plot_speedup

关键源码片段

benchmark/kernels/quantization/bench_fp4_quant.py benchmark

完全重写,用于对比 sglang jit kernel 与 FlashInfer Cute-DSL 的 FP4 量化性能,新增绘图函数和 CSV 输出,验证新后端性能优势。

"""Benchmark FP4 quantize: sglang jit_kernel vs flashinfer.Compares ``sglang.jit_kernel.nvfp4.scaled_fp4_quant`` against
``flashinfer.fp4_quantize`` over a sweep of (M, K) shapes.Timing uses ``flashinfer.testing.bench_gpu_time`` (CUDA-graph based with
rotating-buffer cold-L2).
"""import argparse
import itertoolsimport numpy as np
import torch
from flashinfer import fp4_quantize as flashinfer_fp4_quantize
from flashinfer.testing import bench_gpu_timefrom sglang.jit_kernel.nvfp4 import scaled_fp4_quantMs = [1, 8, 32, 128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
Ks = [128, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 8192, 16384]
​
​
def _bench(fn, input_args) -> float:
    """用 flashinfer 的 bench_gpu_time 进行 CUDA-graph 计时,返回中位数耗时(毫秒)。"""
    times = bench_gpu_time(
        fn=fn,
        input_args=input_args,
        use_cuda_graph=True,
        dry_run_time_ms=25,
        repeat_time_ms=100,
    )
    return float(np.median(times))
​
​
def benchmark(M: int, K: int, dtype: torch.dtype, device: str):
    """对给定形状 (M, K) 分别运行 sglang 和 flashinfer 的量化,返回各自耗时(毫秒)。"""
    x = torch.randn(M, K, device=device, dtype=dtype)
    global_scale = torch.ones(1, device=device, dtype=torch.float32)
​
    sglang_ms = _bench(
        lambda x, gs: scaled_fp4_quant(x, gs),
        input_args=(x, global_scale),
    )
    flashinfer_ms = _bench(
        lambda x, gs: flashinfer_fp4_quantize(x, gs, backend="cute-dsl"),
        input_args=(x, global_scale),
    )
    return sglang_ms, flashinfer_ms
​
​
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--csv", type=str, default=None)
    parser.add_argument("--plot", type=str, default=None)
    args = parser.parse_args()
​
    dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
​
    rows = []
    header = f"{'M':>8} {'K':>8} {'sglang(us)':>12} {'flashinfer(us)':>16} {'speedup':>10}"
    print(header)
    print("-" * len(header))
​
    for M, K in itertools.product(Ms, Ks):
        try:
            sglang_ms, flashinfer_ms = benchmark(M, K, dtype, args.device)
        except Exception as e:
            print(f"{M:>8} {K:>8}  skipped: {e}")
            continue
        # 转换为微秒
        sglang_us = sglang_ms * 1e3
        flashinfer_us = flashinfer_ms * 1e3
        speedup = flashinfer_us / sglang_us # >1 表示 flashinfer 更快
        print(f"{M:>8} {K:>8} {sglang_us:>12.3f} {flashinfer_us:>16.3f} {speedup:>10.3f}")
        rows.append((M, K, sglang_us, flashinfer_us, speedup))
​
    if args.csv:
        with open(args.csv, "w") as f:
            f.write("M,K,sglang_us,flashinfer_us,speedup_flashinfer_over_sglang\n")
            for M, K, s, fi, sp in rows:
                f.write(f"{M},{K},{s:.6f},{fi:.6f},{sp:.6f}\n")
        print(f"Saved CSV to {args.csv}")
​
    if args.plot:
        plot_speedup(rows, args.plot)
​
​
if __name__ == "__main__":
    main()

评论区精华

fp4_quantize 函数位置重构 设计

Fridge003 建议将 fp4_quantize 从 modelopt_quant.py 移到 fp4_utils.py,认为更合适。

结论:b8zhong 接受建议并移动。 · 已解决

未使用的 jit kernel 是否删除 style

Fridge003 询问如果 jit kernel 不再使用,应删除。

结论:b8zhong 表示可作为后续工作,暂不删除。 · deferred

piecewise CUDA graph 兼容性测试 正确性

Fridge003 询问是否测试过在 piecewise CUDA graph 下的工作情况。

结论:b8zhong 展示了启动服务器并使用 CUDA graph 的成功日志,证明 register_custom_op 方案有效。 · 已解决

风险与影响

依赖FlashInfer新版本(需支持cute-dsl backend),若flashinfer不可用则fp4_quantize为None,需注意fallback。默认backend改为cute-dsl,在非SM100设备上自动回退到cuda,但可能因flashinfer版本差异导致行为不一致。旧代码中依赖sglang jit kernel的路径虽未被删除,但已不再使用,若有外部代码直接引用可能失效。

用户:在搭载Blackwell GPU的系统上,FP4量化推理速度提升明显;开发者:量化函数集中管理,后续维护更容易;测试:基准测试脚本重写,支持更全面的shape覆盖和绘图。

依赖 FlashInfer 版本 CUDA Graph 兼容依赖 custom_op 移除旧导入对外部代码的影响

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论