Prhub

#22625 [diffusion] model: support JoyAI-Image-Edit

原始 PR 作者 lahmuller 合并时间 2026-05-02 14:08 文件变更 17 提交数 13 评论 22 代码增减 +2344 / -12

执行摘要

支持 JoyAI-Image-Edit 图像编辑模型

PR 由 JoyAI Team 提出,希望基于 SGLang 扩散框架提供图像编辑能力。JoyAI-Image 是一个统一的多模态基础模型,结合了 8B MLLM 和 16B 多模态扩散 Transformer。此 PR 使 SGLang 能够加载并运行 JoyAI-Image-Edit 模型进行指令引导的图像编辑推理。

推荐关注此 PR 的设计模式:它展示了如何通过配置驱动的方式将新扩散模型集成到 SGLang 框架中,特别是通过管道配置的 postprocess_text_funcs 实现后处理泛化。对模型集成者有参考价值。建议精读 qwen3vl.py 的编码器实现和 joy_image.py(runtime)的 DiT 双流块实现。

讨论亮点

Review 中最关键的讨论是 BBuf 提出的两点:

1) 通用化后处理导致 Qwen-Image-Edit 的 drop_idx 从 64 变为默认 34,可能改变现有模型的条件 token;
2) latents[0] 只取第一个输出的处理会导致批量请求时丢失其他输出。作者均已修复。此外,gemini-code-assist 建议将 WanVAEConfig.post_init 合并到 __post_init__ 并避免 bucket_configs 在请求处理中初始化,也已修复。变量 shadowing 问题(input_idsconfig)在审查时未明确修复。

实现拆解

  1. 注册与配置:在 multimodal_gen/registry.py 中添加 JoyAI 模型路径检测和配置映射,新增 JoyImageEditPipelineConfigJoyImageEditSamplingParams
  2. Qwen3-VL 文本编码器:新增 qwen3vl.py 配置文件(Qwen3VLArchConfig/Qwen3VLConfig)和运行时模型文件(runtime/models/encoders/qwen3vl.py),实现了完整的 Qwen3VLTextAttentionQwen3VLTextModelQwen3VLModel 等模块,用于文本和视觉条件编码。
  3. JoyImage DiT 模型:新增 joy_image.py 配置文件(JoyImageArchConfig/JoyImageDiTConfig)和运行时模型文件(runtime/models/dits/joy_image.py),包含 MMDoubleStreamBlockModulateWanJoyTransformer3DModel 等,支持序列并行和权重名称映射。
  4. 编辑管道:新增 JoyImageEditPipeline 类,通过 create_pipeline_stages 编排标准 T2I 阶段。JoyImageEditPipelineConfig 设置默认采样参数(guidance_scale=4.0, num_inference_steps=40, num_frames=1)。
  5. 框架泛化:将图像编码阶段的后处理函数抽象为管道可配置的 postprocess_text_funcs,保留 Qwen-Image-Edit 原有的 drop_idx=64 行为。在 WanVAEConfig 中添加 get_vae_scale_factor()post_init() 统一缩放因子获取。
  6. 测试:新增 JoyAI 图像编辑 1 GPU 扩散 CI 烟雾测试,包含性能基准数据。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/models/encoders/qwen3vl.py 编码器 added 9.36
python/sglang/multimodal_gen/runtime/models/dits/joy_image.py DiT 模型 added 9.36
python/sglang/multimodal_gen/configs/pipeline_configs/joy_image.py 管道配置 added 9.08
python/sglang/multimodal_gen/configs/models/encoders/qwen3vl.py 编码器配置 added 8.77
python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py 图像编码阶段 modified 7.32

关键符号

Qwen3VLTextAttention.forward MMDoubleStreamBlock.forward JoyImageEditPipeline.create_pipeline_stages joy_image_postprocess_text JoyImageEditPipelineConfig.__post_init__

关键源码片段

python/sglang/multimodal_gen/runtime/models/encoders/qwen3vl.py data-contract

新增的 Qwen3-VL 文本编码器运行时实现,包含 Attention、MLP、DecoderLayer、Model 等完整模块,是 JoyAI 模型理解文本和图像条件的关键组件。

class Qwen3VLTextAttention(nn.Module):
    # 使用 LocalAttention 替代原生 FlashAttention,兼容不同平台
    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # 标准线性投影
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
        # Q/K 各自独立应用 RMSNorm(区别于共用 norm 的做法)
        self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # 本地注意力层,支持 FA 和 torch_sdpa 后端
        self.attn = LocalAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            num_kv_heads=self.num_key_value_heads,
            softmax_scale=self.scaling,
            causal=True,
            supported_attention_backends=(AttentionBackendEnum.FA, AttentionBackendEnum.TORCH_SDPA),
        )
​
    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_values=None, cache_position=None, **kwargs):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # QKV 投影:Q 和 K 各自带归一化
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        # 应用旋转位置编码(RoPE)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        # 可选:更新 past_key_values 缓存
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # 调用 LocalAttention 执行注意力计算
        attn_output = self.attn(query_states, key_states, value_states, attn_mask=attention_mask)
        # 输出投影
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(hidden_states.shape)
        attn_output = self.o_proj(attn_output)
        return attn_output, None
python/sglang/multimodal_gen/runtime/models/dits/joy_image.py data-contract

新增的 JoyImage DiT 模型运行时实现,包含双流块 ModulateWan、MMDoubleStreamBlock、JoyTransformer3DModel 等,是生成图像编辑结果的核心扩散骨干。

class MMDoubleStreamBlock(nn.Module):
    # 双流块:图像流和文本流各自独立进行调制、注意力、MLP,然后通过门控融合
    def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor):
        # 从条件向量中切分出调制参数(6 个因子:缩放、平移、门控)
        img_mod1, img_mod2 = self.img_mod(vec).chunk(2, dim=1)
        txt_mod1, txt_mod2 = self.txt_mod(vec).chunk(2, dim=1)
        # 图像流:归一化 + 双流注意力
        img_norm1 = self.fused_modulate_img_norm1(img, img_mod1[:, :3])
        img_attn_out, _ = self.img_attn(img_norm1, pe)
        img = fused_add_gate(img, img_attn_out, img_mod1[:, 3])
        # 文本流:归一化 + 双流注意力
        txt_norm1 = self.fused_modulate_txt_norm1(txt, txt_mod1[:, :3])
        txt_attn_out, _ = self.txt_attn(txt_norm1, pe)
        txt = fused_add_gate(txt, txt_attn_out, txt_mod1[:, 3])
        # 图像 MLP
        img_norm2 = self.fused_modulate_img_norm2(img, img_mod2[:, :3])
        img_mlp_out = self.img_mlp(img_norm2)
        img = fused_add_gate(img, img_mlp_out, img_mod2[:, 3])
        # 文本 MLP
        txt_norm2 = self.fused_modulate_txt_norm2(txt, txt_mod2[:, :3])
        txt_mlp_out = self.txt_mlp(txt_norm2)
        txt = fused_add_gate(txt, txt_mlp_out, txt_mod2[:, 3])
        return img, txt
python/sglang/multimodal_gen/configs/pipeline_configs/joy_image.py dependency-wiring

新增的 JoyImage 编辑管道配置,定义了任务类型、组件配置、采样默认值和桶生成策略,是模型集成的入口配置。

@dataclass
class JoyImageEditPipelineConfig(ImagePipelineConfig):
    task_type: ModelTaskType = ModelTaskType.I2I # 图像到图像编辑
    dit_config: DiTConfig = field(default_factory=JoyImageDiTConfig)
    vae_config: VAEConfig = field(default_factory=WanVAEConfig)
    vae_tiling: bool = False # 编辑场景无需 VAE 分片
    vae_sp: bool = False
    flow_shift: float = 1.5
    # 使用 Qwen3-VL 作为文本 + 视觉编码器
    text_encoder_configs: tuple[EncoderConfig, ...] = field(default_factory=lambda: (Qwen3VLConfig(),))
    enable_torch_compile: bool = False
    # 全 bf16 精度
    precision: str = "bf16"
    vae_precision: str = "bf16"
    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("bf16",))
    # 后处理函数可配置,允许不同编码器使用不同截断策略
    postprocess_text_funcs: tuple[Callable, ...] = field(default_factory=lambda: (joy_image_postprocess_text,))
    prioritize_frame_matching: bool = True
    bucket_configs: list[tuple[int, int, int, int, int]] = field(init=False)
​
    def __post_init__(self):
        # 初始化桶配置:1024 分辨率,单帧,批量大小 8,最多 6 项
        self.bucket_configs = self.generate_video_image_bucket(
            basesize=1024, min_temporal=1, max_temporal=1,
            bs_img=8, bs_vid=4, bs_mimg=8, min_items=1, max_items=6,
        )

评论区精华

通用化后处理导致 Qwen-Image-Edit drop_idx 回归 正确性

BBuf 指出将 image_edit 后处理通用化后,Qwen-Image-Edit 的 drop_idx 从 64 降为默认 34,会改变现有模型的条件 token。

结论:作者恢复 edit 特定分支,使用 drop_idx=64,问题已解决。 · 已解决

latents[0] 选择导致批量优化丢失 正确性

BBuf 提示 JoyImageEditPipelineConfig 的 post_denoising_loop 中 latents[0] 仅取第一个输出,批量请求会丢弃其余输出。

结论:作者对齐了条件 latents 的批量维度,并在多处添加严格维度检查,问题已解决。 · 已解决

WanVAEConfig.post_init 冗余设计 设计

gemini-code-assist 建议将 post_init 逻辑合并到 __post_init__,以遵循标准 dataclass 模式。

结论:作者已将逻辑迁移到 __post_init__,问题已解决。 · 已解决

风险与影响

主要风险在于:

1) 回归风险:对 image_encoding.py 的泛化改动了原有 Qwen-Image-Edit 路径,虽经修复,但可能仍有未覆盖的边角情况。
2) 依赖上游权重格式:JoyAI-Image-Edit 模型权重名称仍在变化(依赖 huggingface/diffusers#13444),合并后如权重变动可能导致加载失败。
3) 测试覆盖率有限:仅添加了烟雾生成测试,未包含单元测试或一致性校验,可能存在未发现的逻辑错误。
4) 批量请求支持虽已修补,但未经过充分验证。

对用户,新增了 JoyAI-Image-Edit 模型支持,可以用于指令引导的图像编辑。对系统,新增约 2300 行代码,主要是新模型实现和配置,对现有框架核心路径改动较小,不引入性能回归。对团队,需要维护新模型代码,并跟踪上游权重变化。

泛化改动影响现有路径 依赖上游权重格式 测试覆盖率有限 批量支持验证不足

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论