Prhub

#26327 [diffusion] fix: fix diffusion LoRA consistency cases

原始 PR 作者 mickqian 合并时间 2026-05-28 06:29 文件变更 6 提交数 9 评论 2 代码增减 +74 / -13

执行摘要

修复 Diffusion LoRA 精度与一致性验证

根据 PR body,主要动机是修复 LoRA 合并精度(默认 FP32 合并)、保护缓存 LoRA 条目中的合并强度、加载适配器级别的 lora_alpha,以及为 LoRA API 生成添加一致性验证,确保启动/应用 LoRA、merge_lora_weights、set_lora、动态切换和多 LoRA 切换路径生成像素与 GT 比较。

该 PR 修复了 diffusion LoRA 多个边界情况,并加强了测试覆盖,值得 review 和 merge。特别关注 FP32 合并默认值变更和 lora_alpha 加载的设计决策。

讨论亮点

本次 PR 无公开 review 讨论。

实现拆解

  1. 合并精度提升:在 linear.py_should_merge_in_fp32 中,将环境变量 SGLANG_DIFFUSION_LORA_MERGE_FP32 的默认值从 0 改为 1,即默认启用 FP32 合并;同时在 merge_lora_weights 中,当传入新 strength 时同步更新 lora_weights_list 中每个条目的 strength。
  2. lora_alpha 加载与缓存:在 lora_pipeline.pyload_lora_adapter 中,新增从 adapter_config.json 读取 lora_alpha,并存入新字典 loaded_adapter_alphas;在 _apply_lora_to_layers 中,当查找 alpha 时优先使用该缓存值,若没有则回退到 rank。
  3. set_lora 路径复用:在 set_lora 中,当传入的 lora_paths 某项为 None 时,从 loaded_adapter_paths 中根据 nickname 获取已有路径,避免重复加载。
  4. 一致性测试增强:在 test_server_common.py 中新增 _validate_lora_consistency 方法,并在 _test_lora_api_functionality_test_lora_dynamic_switch_e2e 中增加对合并、设置、动态切换后生成内容的一致性检查。
  5. CI 数据基线更新:更新 test_utils.py 中的 SGL_TEST_FILES_CI_DATA_REVISION 指向包含新 LoRA GT 的提交。
  6. 其他配套:修复 test_lora_pipeline.py 中测试夹具未初始化 loaded_adapter_alphas 的问题,并为 gpu_cases.py 中 Wan2.2 案例添加 --lora-merge-mode dynamic 参数。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py LoRA 管线 modified 6.83
python/sglang/multimodal_gen/runtime/layers/lora/linear.py LoRA 合并 modified 6.02
python/sglang/multimodal_gen/test/server/test_server_common.py 服务测试 modified 5.96
python/sglang/multimodal_gen/test/test_utils.py 测试工具 modified 3.92
python/sglang/multimodal_gen/test/unit/test_lora_pipeline.py 单元测试 modified 3.28
python/sglang/multimodal_gen/test/server/gpu_cases.py GPU 用例 modified 3.25

关键符号

_should_merge_in_fp32 merge_lora_weights load_lora_adapter set_lora _apply_lora_to_layers _validate_lora_consistency _test_lora_api_functionality _test_lora_dynamic_switch_e2e

关键源码片段

python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py core-logic

核心 LoRA 管线,新增 lora_alpha 加载、缓存和 set_lora 路径复用逻辑。

# python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py
# 关键变更:加载适配器级别的 lora_alpha 并缓存def load_lora_adapter(self, lora_path: str, lora_nickname: str) -> None:
    # ... 下载和归一化 state_dict ...
    raw_state_dict = load_file(lora_local_path)
    lora_state_dict = normalize_lora_state_dict(raw_state_dict, logger=logger)
​
    # 尝试从 adapter_config.json 读取 lora_alpha(适配器级别)
    adapter_lora_alpha = None
    adapter_config_path = os.path.join(
        os.path.dirname(lora_local_path), "adapter_config.json"
    )
    if os.path.isfile(adapter_config_path):
        with open(adapter_config_path, encoding="utf-8") as f:
            adapter_config = json.load(f)
        if adapter_config.get("lora_alpha") is not None:
            adapter_lora_alpha = int(adapter_config["lora_alpha"])
    # ... 清空或初始化 self.lora_adapters[lora_nickname] ...
​
    # 填充 lora_adapters 字典
    for target_name, weight in lora_state_dict.items():
        # ... 过滤和映射 ...
        self.lora_adapters[lora_nickname][target_name] = weight.to(self.device)
​
    # 缓存路径和 alpha
    self.loaded_adapter_paths[lora_nickname] = lora_path
    self.loaded_adapter_alphas[lora_nickname] = adapter_lora_alpha # 新增:缓存适配器级 alpha
​
    logger.info("Rank %d: loaded LoRA adapter %s", rank, lora_path)
python/sglang/multimodal_gen/runtime/layers/lora/linear.py core-logic

LoRA 合并精度核心逻辑,修改默认 FP32 合并并修复 strength 传播。

# python/sglang/multimodal_gen/runtime/layers/lora/linear.py
# 关键变更:默认 FP32 合并 + 更新 strength 到列表def _should_merge_in_fp32(self, lora_list: list[LoRAWeightEntry]) -> bool:
    # 环境变量默认值从 "0" 改为 "1":默认启用 FP32 合并
    if os.getenv("SGLANG_DIFFUSION_LORA_MERGE_FP32", "1") != "1":
        return False
    # 对 distilled-lora 路径仍然不使用 FP32(防止精度问题)
    for _, _, lora_path, _, _, _ in lora_list:
        if lora_path and "distilled-lora" in lora_path.lower():
            return False
    return True@torch.no_grad()
def merge_lora_weights(self, strength: float | None = None) -> None:
    if strength is not None:
        self.strength = strength
        # 新增:将新 strength 应用到 lora_weights_list 的每个条目,保证强度一致
        if self.lora_weights_list:
            self.lora_weights_list = [
                (lora_A, lora_B, lora_path, strength, lora_rank, lora_alpha)
                for (lora_A, lora_B, lora_path, _, lora_rank, lora_alpha)
                in self.lora_weights_list
            ]
    # ... 后续合并逻辑不变 ...
python/sglang/multimodal_gen/test/server/test_server_common.py test-coverage

新增 LoRA 一致性验证方法,增强端到端测试覆盖。

# python/sglang/multimodal_gen/test/server/test_server_common.py
# 新增:LoRA 一致性验证辅助方法def _validate_lora_consistency(
    self, case: DiffusionTestCase, content: bytes, operation: str
) -> None:
    """
    验证 LoRA 操作后的生成内容与 ground truth 一致性。
    Args:
        case: 测试用例配置,包含一致性检查开关
        content: 生成的图像或视频字节
        operation: 操作描述,用于日志
    """
    if not case.run_consistency_check:
        logger.info(
            "[LoRA Consistency] Skipping %s consistency for %s: disabled for case",
            operation,
            case.id,
        )
        return
​
    logger.info(
        "[LoRA Consistency] Validating %s output for %s", operation, case.id
    )
    self._validate_consistency(case, content)

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  • FP32 合并默认启用:增加显存占用和计算开销,但对 diffusion 推理整体影响较小(仅合并阶段),精度提升是值得的。
  • 一致性测试依赖远程 GT 仓库:新的 _validate_lora_consistency 会通过网络下载 GT 文件,如果网络不稳定或 GT 版本未同步,可能导致测试失败。
  • lora_alpha 文件缺失adapter_config.json 可能不存在或格式异常,当前代码仅做了存在性检查,缺少异常处理,但 fallback 到 rank 机制可缓解。
  • 路径复用逻辑set_lora 中依赖 loaded_adapter_paths 缓存,若缓存被错误清除可能引发后续错误。
  • 用户:使用 diffusion LoRA API 的用户将获得更精确的权重合并和更稳定的多 adapter 切换体验。
  • 系统:默认 FP32 合并增加了少量计算负荷,但提升了生成质量;新增的一致性校验在 CI 中运行,不影响线上服务。
  • 团队:新增的端到端测试覆盖了 LoRA 核心 API 路径,降低回归风险,便于后续重构。
FP32 合并默认启用可能影响性能 一致性测试依赖外部 GT 仓库 lora_alpha 文件可能缺失

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论