Prhub

#23538 [NPU] Fix Z-Image negative-branch rotary embeddings for CFG

原始 PR 作者 gxxx-hum 合并时间 2026-05-03 21:18 文件变更 2 提交数 9 评论 8 代码增减 +51 / -1

执行摘要

修复 Z-Image 负提示旋转嵌入使用正提示长度的 bug

Z-Image在使用CFG生成图像时,负分支的旋转位置编码形状错误(32 vs 192),导致Tensor尺寸不匹配的运行时错误。该Bug由PR body中提供的堆栈跟踪和复现步骤明确报告。

该PR值得审阅以理解扩散模型中CFG分支处理的常见陷阱;设计简单明了,适合作为bugfix范例。

讨论亮点

审查者OrangeRedeng要求添加CI测试以避免未来回归,贡献者gxxx-hum同意并提交了测试。合并者ping1jing2指出GPU CI出现另一个错误(由#23625引起)并确认NPU CI正常后合并。

实现拆解

  1. 修改prepare_neg_cond_kwargs方法python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py,第363-383行):新增prompt_embeds变量,优先使用batch.negative_prompt_embeds[0](若存在),否则回退到batch.prompt_embeds[0]。将get_freqs_cis的第一个参数从此前的batch.prompt_embeds[0]替换为prompt_embeds,确保负分支使用正确的嵌入长度。
  2. 新增单元测试python/sglang/multimodal_gen/test/unit/test_zimage_pipeline_config.py,全文件):添加TestZImagePipelineConfig.test_zimage_negative_prompt_rotary_embeddings_use_negative_prompt_len方法,模拟不同正/负序列长度(19 vs 45),断言prepare_neg_cond_kwargs返回的freqs_cis中位置ID的形状与负提示序列长度对齐,验证修复正确性。
文件 模块 状态 重要度
python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py 扩散配置 modified 5.8
python/sglang/multimodal_gen/test/unit/test_zimage_pipeline_config.py 测试 added 6.26

关键符号

prepare_neg_cond_kwargs get_freqs_cis test_zimage_negative_prompt_rotary_embeddings_use_negative_prompt_len

关键源码片段

python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py core-logic

核心修复:修改 `prepare_neg_cond_kwargs` 以使用负提示嵌入的长度构建 RoPE。

# python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py
class ZImagePipelineConfig:
    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
        # 修复:使用负提示嵌入(如果存在),否则回退到正提示嵌入
        prompt_embeds = (
            batch.negative_prompt_embeds[0]
            if batch.negative_prompt_embeds is not None
            else batch.prompt_embeds[0]
        )
        return {
            "freqs_cis": self.get_freqs_cis(
                prompt_embeds, # 之前这里错误地使用了 batch.prompt_embeds[0]
                batch.width,
                batch.height,
                device,
                rotary_emb,
                batch,
            ),
            "image_seq_len_target": (
                self._get_zimage_sp_plan(batch)["img_seq_target"]
                if get_sp_world_size() > 1
                else None
            ),
        }
python/sglang/multimodal_gen/test/unit/test_zimage_pipeline_config.py test-coverage

新增单元测试验证修复,确保负分支使用负提示长度。

# python/sglang/multimodal_gen/test/unit/test_zimage_pipeline_config.py
import unittest
from types import SimpleNamespace
from unittest.mock import patchimport torchfrom sglang.multimodal_gen.configs.pipeline_configs.zimage import ZImagePipelineConfig
​
​
class TestZImagePipelineConfig(unittest.TestCase):
    @patch("sglang.multimodal_gen.configs.pipeline_configs.zimage.get_sp_world_size")
    def test_zimage_negative_prompt_rotary_embeddings_use_negative_prompt_len(
        self, mock_get_sp_world_size
    ) -> None:
        """Negative CFG branch should build RoPE positions from negative prompt embeds."""
        mock_get_sp_world_size.return_value = 1
​
        config = ZImagePipelineConfig()
        pos_seq_len = 19
        neg_seq_len = 45
        batch = SimpleNamespace(
            prompt_embeds=[torch.ones(pos_seq_len, 2560)],
            negative_prompt_embeds=[torch.ones(neg_seq_len, 2560)],
            height=16,
            width=16,
        )
​
        def rotary_emb(pos_ids):
            return pos_ids
​
        neg_kwargs = config.prepare_neg_cond_kwargs(
            batch=batch,
            device=torch.device("cpu"),
            rotary_emb=rotary_emb,
            dtype=torch.float32,
        )
​
        cap_pos_ids, image_pos_ids = neg_kwargs["freqs_cis"]
        neg_cap_padded_len = 64
        # 断言:caption 位置 ID 的形状应为 (64, 3),基于负提示填充长度
        self.assertEqual(cap_pos_ids.shape, (neg_cap_padded_len, 3))
        # 断言:第一个图像位置 ID 正确反映了填充偏移
        self.assertEqual(image_pos_ids[0].tolist(), [neg_cap_padded_len + 1, 0, 0])
​
​
if __name__ == "__main__":
    unittest.main()

评论区精华

添加 Z-Image CI 测试 测试

OrangeRedeng 建议添加 CI 测试以避免未来回归;gxxx-hum 同意并在 PR 中新增了单元测试。

结论:但该 PR 仅添加了单元测试,未集成到 CI 流水线;测试在 Python 端已覆盖。 · 已解决

GPU CI 错误 other

ping1jing2 指出 GPU CI 存在另一个错误(#23625),但 NPU CI 通过。

结论:确认无关后合并。 · 已解决

风险与影响

风险极低:变更仅影响Z-Image模型的负分支RoPE构造,且逻辑简单(首选负提示嵌入,降级到正提示)。单元测试覆盖了核心场景,不会影响其他模型或正常分支。GPU CI的失败与此次PR无关。

影响范围仅限于使用Z-Image模型且启用CFG(Classifier-Free Guidance)的用户。修复后,具有负提示的生成将正确工作,消除尺寸不匹配错误。无统计效果或兼容性问题。

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论