执行摘要
- 一句话:SM100默认使用Cute-DSL NVFP4量化,性能提升
- 推荐动作:本PR值得关注其通过注册custom_op实现CUDA graph兼容的技巧,以及在不同后端间自动选择的设计模式。对于要修改量化后端的开发者,是很好的参考。
功能与动机
在Blackwell(SM100)上,FlashInfer的Cute-DSL FP4量化核经过性能优化后全面超越原始CUDA核(参见flashinfer#2904),因此需要将其集成并设为默认后端,以获得最高性能。
实现拆解
- 在fp4_utils.py中引入FlashInfer的fp4_quantize,根据SM100判断自动选择cute-dsl或cuda backend。
- 使用register_custom_op_from_extern注册量化算子,并提供fake实现以支持CUDA graph捕获。
- 移除modelopt_quant.py中旧的fp4_quantize导入和fallback代码,改为从fp4_utils导入。
- 更新standard.py、flashinfer.py、compressed_tensors等文件中的导入路径,统一引用fp4_utils的fp4_quantize。
- 重写bench_fp4_quant.py基准脚本,直接对比sglang和flashinfer的量化性能,并增加绘图和CSV输出。
关键文件:
benchmark/kernels/quantization/bench_fp4_quant.py(模块 性能测试;类别 source;类型 benchmark;符号 _bench, benchmark, main, plot_speedup): 完全重写,用于对比sglang jit kernel与FlashInfer Cute-DSL的FP4量化性能,新增绘图函数和CSV输出,验证新后端性能优势。
python/sglang/srt/layers/quantization/fp4_utils.py(模块 量化工具;类别 source;类型 core-logic;符号 _round_up, _flashinfer_fp4_quantize_impl, _flashinfer_fp4_quantize_fake): 核心变更文件,集中了FP4量化函数的包装逻辑,通过register_custom_op_from_extern注册可被CUDA graph捕获的算子,并根据SM100自动选择cute-dsl后端。
python/sglang/srt/layers/quantization/modelopt_quant.py(模块 量化配置;类别 source;类型 dependency-wiring): 删除旧有的fp4_quantize导入逻辑,改为从fp4_utils导入,清理代码并统一入口。
关键符号:_flashinfer_fp4_quantize_impl, _flashinfer_fp4_quantize_fake, _round_up, benchmark, _bench, main, plot_speedup
关键源码片段
benchmark/kernels/quantization/bench_fp4_quant.py
完全重写,用于对比sglang jit kernel与FlashInfer Cute-DSL的FP4量化性能,新增绘图函数和CSV输出,验证新后端性能优势。
"""Benchmark FP4 quantize: sglang jit_kernel vs flashinfer.
Compares ``sglang.jit_kernel.nvfp4.scaled_fp4_quant`` against
``flashinfer.fp4_quantize`` over a sweep of (M, K) shapes.
Timing uses ``flashinfer.testing.bench_gpu_time`` (CUDA-graph based with
rotating-buffer cold-L2).
"""
import argparse
import itertools
import numpy as np
import torch
from flashinfer import fp4_quantize as flashinfer_fp4_quantize
from flashinfer.testing import bench_gpu_time
from sglang.jit_kernel.nvfp4 import scaled_fp4_quant
Ms = [1, 8, 32, 128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
Ks = [128, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 8192, 16384]
def _bench(fn, input_args) -> float:
"""用 flashinfer 的 bench_gpu_time 进行 CUDA-graph 计时,返回中位数耗时(毫秒)。"""
times = bench_gpu_time(
fn=fn,
input_args=input_args,
use_cuda_graph=True,
dry_run_time_ms=25,
repeat_time_ms=100,
)
return float(np.median(times))
def benchmark(M: int, K: int, dtype: torch.dtype, device: str):
"""对给定形状 (M, K) 分别运行 sglang 和 flashinfer 的量化,返回各自耗时(毫秒)。"""
x = torch.randn(M, K, device=device, dtype=dtype)
global_scale = torch.ones(1, device=device, dtype=torch.float32)
sglang_ms = _bench(
lambda x, gs: scaled_fp4_quant(x, gs),
input_args=(x, global_scale),
)
flashinfer_ms = _bench(
lambda x, gs: flashinfer_fp4_quantize(x, gs, backend="cute-dsl"),
input_args=(x, global_scale),
)
return sglang_ms, flashinfer_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
parser.add_argument("--device", default="cuda")
parser.add_argument("--csv", type=str, default=None)
parser.add_argument("--plot", type=str, default=None)
args = parser.parse_args()
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
rows = []
header = f"{'M':>8} {'K':>8} {'sglang(us)':>12} {'flashinfer(us)':>16} {'speedup':>10}"
print(header)
print("-" * len(header))
for M, K in itertools.product(Ms, Ks):
try:
sglang_ms, flashinfer_ms = benchmark(M, K, dtype, args.device)
except Exception as e:
print(f"{M:>8} {K:>8} skipped: {e}")
continue
# 转换为微秒
sglang_us = sglang_ms * 1e3
flashinfer_us = flashinfer_ms * 1e3
speedup = flashinfer_us / sglang_us # >1 表示 flashinfer 更快
print(f"{M:>8} {K:>8} {sglang_us:>12.3f} {flashinfer_us:>16.3f} {speedup:>10.3f}")
rows.append((M, K, sglang_us, flashinfer_us, speedup))
if args.csv:
with open(args.csv, "w") as f:
f.write("M,K,sglang_us,flashinfer_us,speedup_flashinfer_over_sglang\n")
for M, K, s, fi, sp in rows:
f.write(f"{M},{K},{s:.6f},{fi:.6f},{sp:.6f}\n")
print(f"Saved CSV to {args.csv}")
if args.plot:
plot_speedup(rows, args.plot)
if __name__ == "__main__":
main()
评论区精华
Fridge003建议将fp4_quantize从modelopt_quant.py移到fp4_utils.py,作者立即执行。对于是否删除未使用的jit kernel,作者表示可作为后续改进。关于piecewise CUDA graph兼容性,作者展示了启动服务器的成功日志,证明register_custom_op方案可以正常工作。
- fp4_quantize函数位置重构 (design): b8zhong接受建议并移动。
- 未使用的jit kernel是否删除 (style): b8zhong表示可作为后续工作,暂不删除。
- piecewise CUDA graph兼容性测试 (correctness): b8zhong展示了启动服务器并使用CUDA graph的成功日志,证明register_custom_op方案有效。
风险与影响
- 风险:依赖FlashInfer新版本(需支持cute-dsl backend),若flashinfer不可用则fp4_quantize为None,需注意fallback。默认backend改为cute-dsl,在非SM100设备上自动回退到cuda,但可能因flashinfer版本差异导致行为不一致。旧代码中依赖sglang jit kernel的路径虽未被删除,但已不再使用,若有外部代码直接引用可能失效。
- 影响:用户:在搭载Blackwell GPU的系统上,FP4量化推理速度提升明显;开发者:量化函数集中管理,后续维护更容易;测试:基准测试脚本重写,支持更全面的shape覆盖和绘图。
- 风险标记:依赖FlashInfer版本, CUDA Graph兼容依赖custom_op, 移除旧导入对外部代码的影响
关联脉络
参与讨论