Prhub

#25305 [diffusion] Fix Z-Image Cache-DiT sequence-parallel override

原始 PR 作者 qimcis 合并时间 2026-05-15 13:25 文件变更 1 提交数 2 评论 7 代码增减 +7 / -2

执行摘要

修复 Cache-DiT 下 Z-Image 的 sequence parallel 覆盖问题

关联 Issue #25254 报告了当启用 Cache-DiT(SGLANG_CACHE_DIT_ENABLED=true)时,Z-Image 模型在推理时崩溃,返回 Internal Server Error。 PR body 描述了根本原因是 Cache-DiT 包装了 transformer blocks,导致 layer.attention 不再是直接的 ZImageAttention 实例,原先直接修改 layer.attention.attn.skip_sequence_parallel 的代码失效。

此 PR 值得精读,尤其是其修复模式——通过参数传递替代直接属性修改,是一种更稳健的设计。对于涉及模型包装和参数覆盖的场景有参考价值。

讨论亮点

PR 无 review 评论讨论,由 mickqian 直接批准。

实现拆解

  1. ZImageAttention.forward 中新增参数:在文件 python/sglang/multimodal_gen/runtime/models/dits/zimage.pyZImageAttention.forward 方法签名中添加 skip_sequence_parallel_override: bool = False 参数。
  2. ZImageAttention.forward 内部传递参数:当调用 self.attn(...) 时,将新增的 skip_sequence_parallel_override 参数传递给 USPAttention 的 forward 方法。
  3. ZImageTransformerBlock.forward 中新增并传递参数:在 ZImageTransformerBlock.forward 方法签名中添加 skip_sequence_parallel_override 参数,并在调用 self.attention(...) 时传递该参数。
  4. ZImageFinalLayer.forward 中新增并传递参数:类似地,在 ZImageFinalLayer.forward 方法签名中添加并传递该参数。
  5. 修改主循环调用点:在 ZImageModel.forward 中,移除原先在每个 step 中直接设置 layer.attention.attn.skip_sequence_parallel = use_full_unified_sequence 的循环,改为在调用 layer(...) 时通过 skip_sequence_parallel_override=use_full_unified_sequence 参数传递。
  6. 无测试文件变更:本次修改未包含测试文件,仅源码变更。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/models/dits/zimage.py 扩散模型 modified 5.92

关键符号

ZImageAttention.forward ZImageTransformerBlock.forward ZImageFinalLayer.forward ZImageModel.forward

关键源码片段

python/sglang/multimodal_gen/runtime/models/dits/zimage.py core-logic

包含所有变更:在 ZImageAttention、ZImageTransformerBlock、ZImageFinalLayer 的 forward 方法中新增 skip_sequence_parallel_override 参数传递,并修改主循环调用方式。

# python/sglang/multimodal_gen/runtime/models/dits/zimage.py
# 修改 ZImageAttention.forward:新增 skip_sequence_parallel_override 参数并传递给 self.attn
class ZImageAttention(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        num_replicated_prefix: int = 0,
        num_replicated_suffix: int = 0,
        skip_sequence_parallel_override: bool = False, # 新增参数,用于替代直接修改 layer.attention.attn.skip_sequence_parallel
    ):
        # ... (q/k/v 计算和形状变换 ) ...
        if num_replicated_suffix > 0:
            # ... (ulysses_attn 分支 ) ...
        else:
            hidden_states = self.attn(
                q,
                k,
                v,
                num_replicated_prefix=num_replicated_prefix,
                num_replicated_suffix=num_replicated_suffix,
                skip_sequence_parallel_override=skip_sequence_parallel_override, # 传递参数
            )
        # ...# 修改 ZImageTransformerBlock.forward:传递 skip_sequence_parallel_override
class ZImageTransformerBlock(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        adaln_input: Optional[torch.Tensor] = None,
        num_replicated_prefix: int = 0,
        num_replicated_suffix: int = 0,
        skip_sequence_parallel_override: bool = False, # 新增参数
    ):
        # ...
        attn_out = self.attention(
            hidden_states,
            freqs_cis=freqs_cis,
            num_replicated_prefix=num_replicated_prefix,
            num_replicated_suffix=num_replicated_suffix,
            skip_sequence_parallel_override=skip_sequence_parallel_override, # 传递
        )
        # ...# 修改 ZImageModel.forward:移除直接属性修改,改用参数传递
class ZImageModel(nn.Module):
    def forward(self, ...):
        # ...
        for layer in self.layers:
            unified = layer(
                unified,
                unified_freqs_cis,
                adaln_input,
                num_replicated_suffix=num_replicated_suffix,
                skip_sequence_parallel_override=use_full_unified_sequence, # 通过参数传递
            )
        # ...

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  1. 回归风险低:修改集中于单个文件 zimage.py,且改动量小(+7/-2),逻辑清晰,不易引入回归。
  2. 影响范围受限:仅影响启用 Cache-DiT 的 Z-Image 模型推理路径。
  3. 缺少测试覆盖:本次修改未包含对应测试用例,未来重构或修改可能缺乏保护。
  • 用户影响:修复了 Z-Image 模型在 Cache-DiT 下的崩溃问题,用户可正常使用图像生成功能。
  • 系统影响:无性能影响,改动仅在推理路径中添加了一个布尔参数传递。
  • 团队影响:为后续 Cache-DiT 和其他 diffusion 模型的集成提供了更安全的参数传递模式。
缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论