Prhub

#21431 [Diffusion] [AMD] Online MXFP4 and FP8 Quantization for Multimodal Generation

原始 PR 作者 ColinZ22 合并时间 2026-05-14 08:52 文件变更 10 提交数 20 评论 27 代码增减 +417 / -17

执行摘要

为 multimodal 扩散添加在线 MXFP4/FP8 量化

PR body 指出目标是为 Z-Image-Turbo 和 Wan 2.2 等模型添加在线 MXFP4(AMD)和 FP8 量化,降低显存并加速推理。性能对比显示 transformer 尺寸减少49-72%,峰值内存减少29-43%,生成时间减少5-18%,且 CLIP 分数保持稳定,证明量化对质量影响很小。

本 PR 值得精读,尤其关注量化配置与线性方法的扩展点设计get_quant_methodpacked_modules_mapping 注入),以及跨模型传递量化参数的模式(在 FeedForward 等子模块中添加 quant_configprefix 参数)。对于计划在 diffusion 模型上支持新量化后端的开发者,这是很好的参考示例。

讨论亮点
  • 代码风格:缓存 _is_hip 结果:mickqian 询问能否避免模块级 local 变量 _is_hip。ColinZ22 回应这是代码库惯例,类似 activation.pyfp8_utils.py,用于性能缓存。结论:维持原样。
  • 文档覆盖:mickqian 要求更新 cli.mdquantization.md。ColinZ22 随后添加了对应文档,满足要求。
  • FP8 路径损坏:HaiShaw 报告 main 上 FP8 路径已损坏(--quantization fp8 报错)。ColinZ22 建议在 Fp8LinearMethod.apply 上添加 @torch.compiler.disable 以规避 Inductor 不可降低 aten._scaled_mm 的问题。该修复不在本 PR 范围内,但提供了临时方案。
  • quantization-ignored-layers 实际使用:avjves 在 issue 评论中指出该 CLI 参数似乎未被使用。后续提交通过 resolve_transformer_quant_load_specpacked_modules_mapping 注册到量化配置,使忽略层逻辑生效。该担忧已解决。

实现拆解

  1. 新增 MXFP4 量化配置与线性方法:在 mxfp4.py 中定义 Mxfp4Config(继承 QuantizationConfig)和 Mxfp4LinearMethod(继承 LinearMethodBase)。初始化时条件导入 AITER 的 gemm_a4w4shuffle_weightdynamic_mxfp4_quant,并在 get_quant_method 中添加跳过小输出层(output_size < 256)和忽略层列表支持。

  2. 修改模型定义透传量化参数:在 zimage.pyFeedForwardZImageBlock 构造函数中添加 quant_configprefix 参数,传递给子线性层,确保量化方法能递归应用到各层。同时为 ZImageTransformer2DModel 添加 packed_modules_mapping 字典,使 is_layer_skipped 能正确匹配融合层。

  3. 扩展 FP8 在线量化支持:在 fp8.pyFp8Config 中添加 packed_modules_mapping 参数,并修改 get_quant_method 调用 is_layer_skipped 时传入该映射,从而支持 --quantization-ignored-layers 对 FP8 路径生效。

  4. 添加 CLI 与加载器配套:在 server_args.py 新增 --quantization--quantization-ignored-layers 参数,更新帮助文本。在 transformer_load_utils.pyresolve_transformer_quant_load_spec 中,从模型类获取 packed_modules_mapping 并注入到量化配置对象中。在 fsdp_load.py 中添加 weight_scaleinput_scale 加载键以支持量化参数。

  5. Flash Attention 回退以兼容 ROCm:在 flash_attention_v3.py 中,当 sgl-kernel 的 FA3 不支持时(如 ROCm),回退到 flash_attn 包的 FA2 实现,确保 MXFP4 量化在 AMD 上也能使用 Flash Attention。

  6. 更新文档:在 quantization.mdcli.md 中添加在线量化的使用说明、选项示例和注意事项。

文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4.py 量化层 added 9.23
python/sglang/multimodal_gen/runtime/models/dits/zimage.py 模型定义 modified 7.18
python/sglang/jit_kernel/flash_attention_v3.py 注意力计算 modified 6.74
python/sglang/multimodal_gen/runtime/server_args.py 服务器配置 modified 6.13
python/sglang/multimodal_gen/runtime/layers/quantization/fp8.py 量化器 modified 5.73

关键符号

Mxfp4Config.__init__ Mxfp4Config.get_quant_method Mxfp4LinearMethod.create_weights Mxfp4LinearMethod.process_weights_after_loading Mxfp4LinearMethod.apply FeedForward.__init__ Fp8Config.__init__ resolve_transformer_quant_load_spec flash_attn_varlen_func

关键源码片段

python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4.py core-logic

核心新增文件,实现 MXFP4 量化配置(Mxfp4Config)与线性方法(Mxfp4LinearMethod),包括条件导入 AITER 内核、小输出层跳过、忽略层列表等关键逻辑。

# python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4.py
import logging
import torch
from sglang.srt.utils import is_hip, mxfp_supported# 缓存 is_hip() 结果避免重复调用
_is_hip = is_hip()if _is_hip:
    try:
        import aiter
        from aiter.ops.gemm_op_a4w4 import gemm_a4w4
        from aiter.ops.shuffle import shuffle_weight
        from aiter.utility.fp4_utils import dynamic_mxfp4_quant
    except ImportError as e:
        # 若 AITER 不可用,所有 kernel 指针置 None,后续禁用 MXFP4
        logger.warning(f"aiter MXFP4 kernels not available: {e}")
        aiter = None
        shuffle_weight = None
        dynamic_mxfp4_quant = None
        gemm_a4w4 = None# gemm_a4w4 在输出维度 N < 256 时精度下降,因此跳过小输出层
_MXFP4_MIN_OUTPUT_DIM = 256class Mxfp4Config(QuantizationConfig):
    """MXFP4 量化配置,适用于 diffusion 模型在线量化"""
​
    def __init__(self, is_checkpoint_mxfp4_serialized=False, ignored_layers=None, packed_modules_mapping=None):
        super().__init__()
        self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized
        self.ignored_layers = ignored_layers or []
        self.packed_modules_mapping = packed_modules_mapping or {}
​
    @classmethod
    def get_name(cls) -> str:
        return "mxfp4"
​
    @classmethod
    def get_min_capability(cls) -> int:
        return 95 # 对应 gfx95x,但仍建议使用 mxfp_supported() 动态判断
​
    def get_quant_method(self, layer, prefix: str):
        # 只量化 LinearBase 子类
        if isinstance(layer, LinearBase):
            # 若层前缀匹配忽略列表,则返回未量化方法
            if is_layer_skipped(prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            # 输出维度 < 256 时保持全精度以避免 ASM kernel 精度问题
            output_size = getattr(layer, "output_size", None)
            if output_size is not None and output_size < _MXFP4_MIN_OUTPUT_DIM:
                return UnquantizedLinearMethod()
            return Mxfp4LinearMethod(self)
        return None
python/sglang/multimodal_gen/runtime/models/dits/zimage.py data-contract

Z-Image 模型适配,在 FeedForward 和 ZImageBlock 中透传 quant_config 与 prefix,使子线性层能被 MXFP4/FP8 量化;添加 packed_modules_mapping 映射支持。

# python/sglang/multimodal_gen/runtime/models/dits/zimage.py
class FeedForward(nn.Module):
    # 新增 quant_config 和 prefix 参数,使子线性层能接收量化配置
    def __init__(self, dim: int, hidden_dim: int,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # 将 quant_config 和 prefix 传递给 MergedColumnParallelLinear 与 RowParallelLinear
        self.w13 = MergedColumnParallelLinear(
            dim, [hidden_dim, hidden_dim], bias=False, gather_output=False,
            quant_config=quant_config, prefix=f"{prefix}.w13")
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True,
            quant_config=quant_config, prefix=f"{prefix}.w2")
        self.act = SiluAndMul()# ZImageTransformer2DModel 类中新增 packed_modules_mapping 静态变量
class ZImageTransformer2DModel(CachableDiT, OffloadableDiTMixin):
    packed_modules_mapping = {
        "w13": ["w1", "w3"], # 映射融合层名称,供 is_layer_skipped 正确识别
    }

评论区精华

代码风格:缓存 is_hip() 结果 style

mickqian 询问能否避免模块级 local `_is_hip`,ColinZ22 解释这是代码库惯例(如 activation.py, fp8_utils.py),用于减少函数调用开销。

结论:维持原样,保留缓存变量。 · 已解决

文档更新要求 documentation

mickqian 要求在新 CLI 参数文档中提及 `cli.md` 和 `quantization.md`。

结论:ColinZ22 添加了对应文档,更新后满足要求。 · 已解决

FP8 路径损坏与 torch.compile 兼容性 正确性

HaiShaw 报告 `--quantization fp8` 在 main 上出错,yichiche 询问与 `--enable-torch-compile` 兼容性。ColinZ22 建议在 `Fp8LinearMethod.apply` 上添加 `@torch.compiler.disable` 以绕开 Inductor 限制。

结论:本 PR 未修复该问题,但提供了临时方案;后续需跟进修复。 · unresolved

quantization-ignored-layers 参数未被实际使用 正确性

avjves 在 issue 评论中指出 `--quantization-ignored-layers` 的值似乎被忽略,未传递到量化逻辑。

结论:后续提交通过 `resolve_transformer_quant_load_spec` 从模型类获取 `packed_modules_mapping` 并注入量化配置,使忽略层逻辑生效。 · 已解决

风险与影响

  • 硬件依赖:MXFP4 量化依赖 AMD MI350+(gfx95x)和 AITER 库,非 AMD 平台或缺少 AITER 时将回退到未量化路径,但可能引入意料之外的 import 错误。
  • FP8 兼容性:当前 --quantization fp8--enable-torch-compile 不兼容(aten._scaled_mm 无法被 Inductor 降低),虽然本 PR 未修复此问题,但用户可能遇到错误。
  • 精度衰退风险:尽管 CLIP 分数验证了图像质量,但层跳过逻辑(小输出层保持未量化)和包装层映射可能影响量化一致性,尤其当 packed_modules_mapping 不完整时。
  • 忽略层配置脆弱性:用户提供的忽略层模式依赖层前缀字符串,若模型结构更新导致前缀变化,配置可能失效。
  • 测试覆盖缺失:本 PR 未包含专门的测试文件,量化路径的可靠性依赖集成测试和 CI,存在回归风险。
  • 用户影响:AMD 用户可显著节省显存(MXFP4 最高 72%)并提升生成速度(21%),FP8 用户也可获得约 49% 的变压器压缩。新 CLI 参数向后兼容,不影响现有工作流。
  • 系统影响:增加对 AITER 内核的条件依赖(仅 AMD 平台加载),不影响其他硬件。GPU 内存占用降低,有利于多实例部署。
  • 团队影响:需维护 MXFP4 量化端到端链路,包括 aiter 版本兼容性。与已有 FP8 量化路径共享部分基础设施(如 is_layer_skipped),减少维护负担。
依赖 AMD MI350+ 与 AITER 库 FP8 与 torch.compile 不兼容 忽略层配置依赖层前缀字符串 缺少单元测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论