Prhub

#25661 [diffusion] model: support FLUX.2-klein-base

原始 PR 作者 alex0dd 合并时间 2026-05-22 11:24 文件变更 10 提交数 4 评论 6 代码增减 +147 / -3

执行摘要

支持 FLUX.2-klein-base 未蒸馏模型,启用 CFG 和 negative prompts

此前 FLUX.2-klein-base 通过蒸馏版 Klein 间接支持,但蒸馏版不支持 negative prompts 和 CFG parallelism。本 PR 直接支持该模型,使其功能完备。

该 PR 实现清晰,适合快速合并。建议关注其后的扩散模型 PR 以了解 FLUX 系列支持的演进。

讨论亮点

维护者 mickqian 要求作者提供结果与官方 diffusers 输出的对比,作者在 Issue 评论中附带了多组图片对比,显示输出质量一致。之后 mickqian 表示“fantastic job, cheers”并批准合并。无其他争议。

实现拆解

  1. 新增 PipelineConfig:在 python/sglang/multimodal_gen/configs/pipeline_configs/flux.py 中添加 Flux2KleinBasePipelineConfig,继承自 Flux2KleinPipelineConfig,设置 should_use_guidance=True 并实现 prepare_neg_cond_kwargs 方法,为 CFG 路径构建 freqs_cis
  2. 新增 SamplingParams:在 python/sglang/multimodal_gen/configs/sample/flux.py 中添加 Flux2KleinBaseSamplingParams,设置默认 num_inference_steps=50guidance_scale=4.0negative_prompt=""
  3. 模型注册:在 python/sglang/multimodal_gen/registry.py 中导入新类,新增 register_configs 调用,并调整原始 Klein 的 detector 逻辑,增加 and "base" not in hf_id.lower() 以区分 base 和非 base 变体。
  4. 放宽 negative_prompt 验证:在 python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py 中,将 CFG 的 negative_prompt 验证从 V.string_not_none(x) 改为 V.string_not_none(x) or isinstance(x, str),允许空字符串(klein-base 的默认空提示)。
  5. 测试与性能基线:在 python/sglang/multimodal_gen/test/server/gpu_cases.py 中添加了 1-GPU CI 测试用例,并在 perf_baselines.json 中记录实测性能数据。同时更新了文档中的兼容性矩阵。
文件 模块 状态 重要度
python/sglang/multimodal_gen/configs/pipeline_configs/flux.py 扩散模型 modified 7.24
python/sglang/multimodal_gen/configs/sample/flux.py 扩散模型 modified 6.35
python/sglang/multimodal_gen/registry.py 注册中心 modified 6.29
python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py 文本编码 modified 5.11
python/sglang/multimodal_gen/test/server/perf_baselines.json 性能基线 modified 4.91
python/sglang/multimodal_gen/test/server/gpu_cases.py GPU 测试 modified 4.27
docs_new/docs/sglang-diffusion/compatibility_matrix.mdx 文档 modified 2.52
docs_new/docs/sglang-diffusion/dynamic_batching.mdx 文档 modified 2.14
.github/workflows/diffusion-ci-gt-gen.yml CI modified 2.38
docs/diffusion/compatibility_matrix.md 文档 modified 1.32

关键符号

Flux2KleinBasePipelineConfig.prepare_neg_cond_kwargs

关键源码片段

python/sglang/multimodal_gen/configs/pipeline_configs/flux.py core-logic

核心:新增 Flux2KleinBasePipelineConfig 及 prepare_neg_cond_kwargs 方法,控制 CFG 行为

@dataclass
class Flux2KleinBasePipelineConfig(Flux2KleinPipelineConfig):
    # Undistilled Klein base model, with guidance embeddings
    should_use_guidance: bool = True
​
    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
        # 获取负 prompt 的文本序列长度
        txt_seq_lens = self.require_text_seq_lens(
            batch,
            0,
            negative=True,
            expected_batch_size=batch.negative_prompt_embeds[0].shape[0],
        )
        # 为负 prompt 构建 rotary embedding 频率,用于 CFG 并行
        return {
            "freqs_cis": self.get_freqs_cis(
                batch.negative_prompt_embeds[0],
                batch.width,
                batch.height,
                device,
                rotary_emb,
                batch,
                txt_seq_lens,
            )
        }
python/sglang/multimodal_gen/configs/sample/flux.py core-logic

定义 Flux2KleinBaseSamplingParams,指定默认推理步数、guidance 和负提示

@dataclass
class Flux2KleinBaseSamplingParams(FluxSamplingParams):
    # Klein-base 是未蒸馏版本,需要 50 步和较大的 guidance
    num_inference_steps: int = 50
    guidance_scale: float = 4.0
    negative_prompt: str = "" # 允许空字符串,CFG 验证通过
python/sglang/multimodal_gen/registry.py core-logic

注册新模型并调整 detector 逻辑,确保 base 和非 base 正确分流

# 导入新配置
from sglang.multimodal_gen.configs.pipeline_configs.flux import (
    Flux2KleinBasePipelineConfig,
    Flux2KleinPipelineConfig,
    Flux2PipelineConfig,
)
from sglang.multimodal_gen.configs.sample.flux import (
    Flux2KleinBaseSamplingParams,
    Flux2KleinSamplingParams,
    Flux2SamplingParams,
    FluxSamplingParams,
)# 注册蒸馏版 Klein,明确排除 base
register_configs(
    sampling_param_cls=Flux2KleinSamplingParams,
    pipeline_config_cls=Flux2KleinPipelineConfig,
    hf_model_paths=[
        "black-forest-labs/FLUX.2-klein-4B",
        "black-forest-labs/FLUX.2-klein-9B",
    ],
    model_detectors=[
        lambda hf_id: (
            "flux.2-klein" in hf_id.lower() or "flux2-klein" in hf_id.lower()
        )
        and "base" not in hf_id.lower() # 排除 base 变体
    ],
)# 注册未蒸馏 Klein-base
register_configs(
    sampling_param_cls=Flux2KleinBaseSamplingParams,
    pipeline_config_cls=Flux2KleinBasePipelineConfig,
    hf_model_paths=[
        "black-forest-labs/FLUX.2-klein-base-4B",
        "black-forest-labs/FLUX.2-klein-base-9B",
    ],
    model_detectors=[
        lambda hf_id: (
            "flux.2-klein" in hf_id.lower() or "flux2-klein" in hf_id.lower()
        )
        and "base" in hf_id.lower() # 仅匹配 base
    ],
)

评论区精华

结果验证 正确性

维护者 mickqian 要求作者提供与官方 diffusers 输出的对比结果。

结论:作者提供了多组图片对比,显示输出质量一致,mickqian 表示满意。 · 已解决

风险与影响

风险较低。主要变更在配置层和注册层,不涉及模型 forward 逻辑或 kernel 修改。潜在风险:模型 detector 逻辑修改可能影响非 base 版 Klein 的匹配(已通过精确排除 "base" 字符串控制);empty string 验证放宽可能使其他模型传入空提示时通过验证,但这与 CFG 设计一致。

对用户:新增对 FLUX.2-klein-base 的官方支持,可正常使用 CFG 和 negative prompts。对系统:无性能回归,新增测试用例和基线。对团队:约 150 行新增代码,维护成本低。

新模型注册 测试覆盖完整 性能基线更新

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论