执行摘要
- 一句话:迁移Wan2.2 ModelOpt CI至NVIDIA官方FP8/NVFP4检查点
- 推荐动作:建议阅读:该PR演示了如何安全地迁移外部依赖并调整内部默认值。值得关注的设计决策是swap_weight_nibbles的fallback链,以及如何通过checkpoint_uses_packed_qkv保持向后兼容。测试修复的根因分析也值得学习。
功能与动机
PR源自#25483的重新提交,旨在使用NVIDIA官方Diffusers FP8/NVFP4检查点替换旧的lmsys transformer overrides,减少维护负担并使Wan2.2量化路径与官方仓库对齐。根因是#25483中默认swap_weight_nibbles变更导致B200 JIT测试失败,本次修复通过显式指定合成检查点参数。
实现拆解
- 变更默认NVFP4配置(modelopt_quant.py):将ModelOptFp4Config的swap_weight_nibbles默认值从True改为False,并在from_config方法中调整fallback顺序,优先使用checkpoint_uses_packed_qkv作为回退键。
- 更新配置合并逻辑(transformer_load_utils.py):在_merge_modelopt_fp4_configs中调整属性复制顺序,现在先处理checkpoint_uses_packed_qkv,再处理swap_weight_nibbles且默认回退False。
- 注册新模型路径(registry.py):为Wan2_2_T2V_A14B注册新的模型路径nvidia/Wan2.2-T2V-A14B-Diffusers-NVFP4,使得直接通过--model-path加载成为可能。
- 修复B200 JIT测试(test_diffusion_nvfp4_scaled_mm.py):在合成预交换检查点测试中显式传递swap_weight_nibbles=True,以匹配经过预先交换的权重数据。
- 更新CI测试用例和文档(gpu_cases.py、testcase_configs.py、quantization.mdx):将Wan2.2测试用例从--transformer-path改为直接使用--model-path指定官方仓库;启用run_consistency_check;更新环境变量;刷新文档表格以反映新检查点来源。
关键文件:
python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py(模块 量化层;类别 source;类型 data-contract;符号 ModelOptFp4Config, from_config, process_weights_after_loading): 该文件是量化配置的核心,修改了ModelOptFp4Config的swap_weight_nibbles默认值(从True改为False),并调整了from_config方法的fallback逻辑,影响所有NVFP4检查点的加载行为。
python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py(模块 构建工具;类别 source;类型 data-contract;符号 build_modelopt_nvfp4_transformer): 构建工具的默认值解析逻辑被简化:移除了对flux1-nvfp4预设的特判,统一默认值为False。影响所有通过该工具构建的混合精度检查点。
python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py(模块 加载器;类别 source;类型 core-logic;符号 _merge_modelopt_fp4_configs): 配置合并函数调整了属性复制顺序:先复制checkpoint_uses_packed_qkv,再处理swap_weight_nibbles且默认回退False,确保NVIDIA官方检查点的加载正确性。
python/sglang/multimodal_gen/registry.py(模块 注册中心;类别 source;类型 core-logic;符号 _register_configs): 注册了新的NVIDIA官方NVFP4模型路径,使得Wan2.2可以直接通过--model-path加载,无需指定transformer-path。
python/sglang/multimodal_gen/test/server/gpu_cases.py(模块 GPU测试;类别 test;类型 test-coverage;符号 MODELOPT_WAN22_FP8_MODEL, MODELOPT_WAN22_NVFP4_B200_ENV_VARS, MODELOPT_WAN22_NVFP4_MODEL): 主要测试变更:将Wan2.2测试用例从--transformer-path改为直接使用--model-path加载官方仓库,并启用一致性检查。
docs_new/docs/sglang-diffusion/quantization.mdx(模块 文档;类别 other;类型 core-logic): 文档更新:反映新的检查点来源矩阵,从lmsys仓库改为NVIDIA官方仓库,并调整了加载方式说明。
关键符号:ModelOptFp4Config.init, ModelOptFp4Config.from_config, ModelOptFp4Config.process_weights_after_loading, _merge_modelopt_fp4_configs, build_modelopt_nvfp4_transformer, _register_configs
关键源码片段
python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py
该文件是量化配置的核心,修改了ModelOptFp4Config的swap_weight_nibbles默认值(从True改为False),并调整了from_config方法的fallback逻辑,影响所有NVFP4检查点的加载行为。
class ModelOptFp4Config(ModelOptQuantConfig):
"""Config class for NVFP4."""
def __init__(
self,
is_checkpoint_nvfp4_serialized: bool = False,
group_size: int = None,
exclude_modules: List[str] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
checkpoint_uses_packed_qkv: bool = False,
swap_weight_nibbles: bool = False, # 默认值从 True 改为 False,匹配官方检查点加载顺序
) -> None:
super().__init__(exclude_modules, packed_modules_mapping)
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning("Detected nvfp4 checkpoint...")
self.group_size = group_size
self.checkpoint_uses_packed_qkv = checkpoint_uses_packed_qkv
self.swap_weight_nibbles = swap_weight_nibbles
@classmethod
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
group_size = None
exclude_modules = []
swap_weight_nibbles = False # 默认值从 True 改为 False
# 扁平格式 (config.json quantization_config)
quant_method = config.get("quant_algo")
if quant_method is not None:
group_size = config.get("group_size")
if group_size is None:
config_groups = config.get("config_groups", {})
if config_groups:
first_group = next(iter(config_groups.values()), {})
group_size = first_group.get("weights", {}).get("group_size")
exclude_modules = config.get("ignore", [])
swap_weight_nibbles = config.get(
"swap_weight_nibbles",
config.get("checkpoint_uses_packed_qkv", False), # 新增 fallback 到 checkpoint_uses_packed_qkv
)
else:
# 嵌套格式 (hf_quant_config.json)
try:
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
group_size = ModelOptFp4Config.common_group_size(config)
exclude_modules = quant_config.get("exclude_modules", [])
swap_weight_nibbles = quant_config.get(
"swap_weight_nibbles",
config.get(
"swap_weight_nibbles",
config.get("checkpoint_uses_packed_qkv", False), # 三层 fallback
),
)
except (ValueError, KeyError):
raise ValueError("Cannot find 'quant_algo' in quantization config.")
# ... 后续省略
python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py
构建工具的默认值解析逻辑被简化:移除了对flux1-nvfp4预设的特判,统一默认值为False。影响所有通过该工具构建的混合精度检查点。
def build_modelopt_nvfp4_transformer(
*,
base_transformer_dir: str,
modelopt_hf_dir: str,
output_dir: str,
pattern_preset: str = "none",
keep_bf16_patterns: Sequence[str] | None = None,
swap_weight_nibbles: bool | None = None, # 由调用方控制,不再有 preset 依赖
overwrite: bool = False,
) -> dict[str, int | bool]:
source_dir = _resolve_transformer_dir(modelopt_hf_dir)
base_dir = _resolve_transformer_dir(base_transformer_dir)
patterns = _preset_patterns(pattern_preset)
if keep_bf16_patterns:
patterns.extend(keep_bf16_patterns)
# 移除了 pattern_preset 为 "flux1-nvfp4" 时返回 True 的特判,统一默认 False
resolved_swap_weight_nibbles = (
swap_weight_nibbles if swap_weight_nibbles is not None else False
)
output_config = _updated_quant_config(
_load_config(source_dir),
fallback_patterns=patterns,
swap_weight_nibbles=resolved_swap_weight_nibbles,
)
# ... 后续省略
python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py
配置合并函数调整了属性复制顺序:先复制checkpoint_uses_packed_qkv,再处理swap_weight_nibbles且默认回退False,确保NVIDIA官方检查点的加载正确性。
def _merge_modelopt_fp4_configs(
existing_config: Optional[ModelOptFp4Config],
inferred_config: Optional[ModelOptFp4Config],
) -> Optional[ModelOptFp4Config]:
"""Merge FP4 configs prioritizing inferred exclude list but preserving repo-level knobs."""
if inferred_config is None:
return existing_config
if _get_quant_config_name(inferred_config) != "modelopt_fp4":
return existing_config or inferred_config
if existing_config is None:
return inferred_config
if _get_quant_config_name(existing_config) != "modelopt_fp4":
return existing_config
# ... exclude_modules 合并逻辑省略
# 关键变更:先处理 checkpoint_uses_packed_qkv,再处理 swap_weight_nibbles
inferred_config.checkpoint_uses_packed_qkv = getattr(
inferred_config, "checkpoint_uses_packed_qkv", False
) or getattr(existing_config, "checkpoint_uses_packed_qkv", False)
inferred_config.swap_weight_nibbles = getattr(
inferred_config, "swap_weight_nibbles", False # 默认值从 True 改为 False
) or getattr(existing_config, "swap_weight_nibbles", False)
if getattr(inferred_config, "group_size", None) is None:
inferred_config.group_size = getattr(existing_config, "group_size", None)
return inferred_config
评论区精华
PR没有直接review评论,但PR body详细描述了根因:
"#25483 intentionally changes the ModelOpt FP4 fallback default for swap_weight_nibbles to False so NVIDIA full-repo checkpoints without the field load in runtime order."
"The B200 JIT test builds a synthetic checkpoint by pre-swapping nibbles with _swap_fp4_nibbles(weight_fp4), but did not pass the knob explicitly, so after the default flip its expected weight stayed one nibble-swap away from the processed layer weight."
该讨论已通过显式传递参数解决。
- swap_weight_nibbles默认值变更导致的测试失败及修复 (correctness): 在合成检查点测试中显式传递swap_weight_nibbles=True以匹配预交换数据。
风险与影响
- 风险:
- 默认值变更影响广泛:swap_weight_nibbles从True变为False会影响所有NVFP4检查点加载,尤其是未显式设置该字段的配置。虽然已通过fallback到checkpoint_uses_packed_qkv缓解,但依赖旧默认值的外部工作流可能出错。
- 模型路径变更:从lmsys transformer改为NVIDIA官方仓库,若官方仓库不可用或更新,CI可能失败。
- 多文件一致性:变更涉及quantization配置、加载器、注册表、测试配置等6个核心文件,任何不一致可能导致运行时错误。
- JIT测试局限:修复仅针对特定合成检查点,若其他测试也依赖旧默认值可能未被覆盖。
- 影响:影响范围包括:
- 用户:使用Wan2.2 NVFP4检查点的用户现在可以指定官方仓库ID直接加载(例如
--model-path nvidia/Wan2.2-T2V-A14B-Diffusers-NVFP4),无需单独准备transformer路径。但需注意默认加载行为变化。
- 系统:NVFP4加载路径统一,减少维护分支。
- CI:B200 CI现在使用官方NVFP4检查点,并启用一致性检查;Wan2.2 FP8/NVFP4测试用例简化,不再依赖lmsys中间仓库。
- 文档:quantization.mdx更新了支持矩阵,移除了lmsys transformer引用。
- 风险标记:默认值变更, 依赖官方检查点可用性, 多模块协调变更, JIT测试脆弱性
关联脉络
- PR #25483 [diffusion] Move Wan2.2 ModelOpt CI/docs from old lmsys overrides to NVIDIA full Diffusers FP8/NVFP4 checkpoints: 本PR是25483的重新提交(reland),修复了其中默认值变更导致的测试失败问题
参与讨论