Prhub

#25168 [diffusion] Role-based component loading and stage affinity

原始 PR 作者 FredHuang99 合并时间 2026-05-22 18:50 文件变更 12 提交数 2 评论 14 代码增减 +991 / -83

执行摘要

扩散管道引入角色感知的组件加载与阶段亲和性

使分解后的 diffusion 角色选择显式化,让 encoder、denoiser 和 decoder 实例只加载和执行它们所需的模块,从而减少 GPU 内存占用并提高组件边界清晰度。

该 PR 值得精读,它是 diffusion 管道迈向分解部署的关键环节,角色分离架构与 stage 亲和性模式对后续多 GPU 部署和内存优化有重要借鉴意义。建议重点关注 _get_extra_allowed_modules_for_role 的设计权衡(中央 vs 分散)以及 stage 创建时的角色跳过逻辑,同时注意未来可能需要的白名单重构。

讨论亮点

Review 中核心讨论点:

  • 中央配置设计权衡:gemini-code-assist 建议将 _get_extra_allowed_modules_for_role 中的白名单分散到各管道类,减少维护瓶颈并遵循开闭原则。作者未直接回应,PR 合入时保持中央配置,该建议暂未采纳。
  • LTX2RefinementStage 兼容性:ShangmingCai 询问改动是否影响 LTX2RefinementStage。作者发现因蛇形命名推断问题导致影响,修复了 stage 名称推断逻辑并添加回归测试。
  • dtype 推断错误处理:ShangmingCai 建议将 _find_conditioner_dtype 中的 pass 改为 warningraise。作者采纳并改为 logger.warning,同时保留回退到 float32 的非致命行为。
  • 代码风格:ShangmingCai 建议将 Qwen 中 text_encoder 两次 to() 调用合并为一次,作者接受并修改。

实现拆解

  1. 扩展角色映射与过滤逻辑: 在 roles.py 中增加角色别名(如 denoisingDENOISER),扩充模块到角色的映射前缀(如 vision_language_encoder→ENCODER,video_dit→DENOISER),并重构 filter_modules_for_role 函数,新增 extra_allowed_modules 参数以支持白名单机制,移除了原来 Encoder 直接放行 Decoder 模块的特殊逻辑。

  2. 基类角色集成: 在 ComposedPipelineBase.__init__ 中调用 validate_disagg_role 进行角色验证,并在模块加载前根据角色过滤 _required_config_modules;新增 _get_extra_allowed_modules_for_role 方法,集中管理各管道在特定角色下需要额外保留的模块(如 Encoder 角色下 Flux2 需要保留 vae)。

  3. Stage 亲和性机制: 为 Hunyuan3DShapeBeforeDenoisingStageMOVADenoisingStageMOVADecodingStageHeliosChunkedDenoisingStage 等 stage 添加 role_affinity 属性,指示该 stage 应由哪个角色执行;基类在创建 stage 时通过 _should_add_stage_for_role 跳过不匹配的 stage。

  4. 管道适配与动态 dtype: 更新 Hunyuan3DShapeBeforeDenoisingStage 构造函数,从依赖 vae/model 改为接收 latent_shape/guidance_embed 等标量,实现 _find_conditioner_dtype_resolve_runtime_dtype 动态推断 dtype;更新 QwenImageLayeredBeforeDenoisingStage 允许外部传入 text_encoder,新增 _resolve_text_encoder_dtype 统一 dtype 推断。

  5. 测试与配套: 新增 test_disagg_roles.py 单元测试文件,覆盖角色枚举解析、模块过滤、特定管道的白名单逻辑、Hunyuan3D 形状 stage 的 dtype 推断以及 LTX2RefinementStage 名称兼容性;其他管道(如 qwen_image.py, mova_pipeline.py)同步更新构造函数参数。

文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py 管道基类 modified 8.99
python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py 3D 形状阶段 modified 8.16
python/sglang/multimodal_gen/test/unit/test_disagg_roles.py 角色测试 added 7.76

关键符号

validate_disagg_role _get_extra_allowed_modules_for_role _should_add_stage_for_role add_stage_factory create_stage _find_conditioner_dtype _resolve_runtime_dtype _resolve_text_encoder_dtype filter_modules_for_role get_module_role role_affinity

关键源码片段

python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py core-logic

核心变更文件,添加角色验证、额外模块白名单、阶段跳过逻辑,重构初始化流程以支持角色过滤,是整个角色分离架构的中心枢纽。

# 摘自 composed_pipeline_base.py —— 角色初始化与白名单核心逻辑class ComposedPipelineBase:
​
    def __init__(self, ...):
        self._disagg_role = server_args.disagg_role
        self.validate_disagg_role(self._disagg_role) # 验证角色是否被支持
​
        # 复制并准备基础配置模块列表
        base_required = (
            required_config_modules
            if required_config_modules is not None
            else self._required_config_modules
        )
        self._required_config_modules = list(base_required)
        self._extra_config_module_map = dict(self._extra_config_module_map)
​
        # 根据角色过滤配置模块,避免加载不需要的模块
        if self._disagg_role != RoleType.MONOLITHIC:
            original = list(self._required_config_modules)
            task_name = self.server_args.pipeline_config.task_type.name.lower()
            self._required_config_modules = filter_modules_for_role(
                self._required_config_modules,
                self._disagg_role,
                extra_allowed_modules=self._get_extra_allowed_modules_for_role(
                    self._disagg_role, task_name
                ),
            )
            skipped = set(original) - set(self._required_config_modules)
            if skipped:
                logger.info(
                    "Disagg role=%s: skipping modules %s",
                    self._disagg_role.value,
                    sorted(skipped),
                )
​
    def _get_extra_allowed_modules_for_role(
        self, role: RoleType, task_name: str
    ) -> set[str]:
        # 集中定义各管道在不同角色下需要额外保留的模块
        # 尽管违反了单一职责,但当前便于审查和维护
        role_to_pipeline_modules: dict[RoleType, dict[str, set[str]]] = {
            RoleType.ENCODER: {
                "Flux2Pipeline": {"vae"},
                "QwenImageEditPipeline": {"vae"},
                "QwenImageLayeredPipeline": {"vae", "transformer"},
                "GlmImagePipeline": {"vae", "transformer"},
                "WanImageToVideoPipeline": {"vae"},
                "MOVAPipeline": {"video_vae", "audio_vae"},
            },
            RoleType.DENOISER: {},
            RoleType.DECODER: {},
        }
        extra_allowed = set(
            role_to_pipeline_modules.get(role, {}).get(self.pipeline_name, set())
        )
​
        # Denoiser 在处理 ti2v 任务时也需要 vae
        if role == RoleType.DENOISER and task_name == "ti2v":
            if self.pipeline_name in {
                "WanImageToVideoPipeline",
                "WanImageToVideoDmdPipeline",
            }:
                extra_allowed.add("vae")
            elif self.pipeline_name == "LTX2Pipeline":
                extra_allowed.update({"vae", "audio_vae"})
​
        return extra_allowed

评论区精华

额外允许模块中央配置的设计权衡 设计

gemini-code-assist 建议将 `_get_extra_allowed_modules_for_role` 中的白名单分散到各个管道类,以减少维护瓶颈并遵循开闭原则。

结论:作者未直接回应,PR 合入时保持中央配置方式,后续可能迭代。 · unresolved

LTX2RefinementStage 阶段名称推断问题 正确性

ShangmingCai 询问改动是否影响 `LTX2RefinementStage`,因为之前存在蛇形命名推断。

结论:作者确认并修复了推断阶段名称的回归问题,并添加了回归测试。 · 已解决

沉默 TypeError 改为 warning robustness

ShangmingCai 建议将 `_find_conditioner_dtype` 中的 `pass` 改为 `warning` 或 `raise`,避免静默失败。

结论:作者接受并改为 `logger.warning`,同时保留回退到 `float32` 的非致命行为。 · 已解决

风险与影响

  1. 中央配置维护性风险_get_extra_allowed_modules_for_role 集中维护所有管道的白名单,当管道数量增长时容易臃肿,且新增管道时需修改基类文件,违反开闭原则,可能成为后续维护瓶颈。
  2. 角色过滤遗漏风险filter_modules_for_role 依赖 get_module_role 的映射,若有新模块未加入映射或白名单,可能导致角色加载缺少必要模块,造成运行时错误。
  3. dtype 回退精度风险_resolve_runtime_dtype 在无法推断时回退到 float32,可能改变预期精度(如原设计为 bfloat16),影响模型输出质量。该回退发生在 Hunyuan3D 和 Qwen 的 stage 中。
  4. 测试覆盖风险:虽然新增了单元测试,但未覆盖所有管道(如 Flux2Klein、GLM Image)的白名单路径,且未集成端到端 role 切换测试,回归风险仍存在。

影响范围:所有使用 ComposedPipelineBase 的 diffusion 管道(Flux2, Flux2Klein, Qwen Image/Edit, GLM Image, Wan I2V, Wan I2V DMD, LTX2, MOVA, Hunyuan3D 等)。对用户:启用 --disagg-role 参数时,不同角色实例将只加载所需模块,减少 GPU 内存占用(如 Encoder 角色不再加载 Denoiser 和 Decoder 模块)。对团队:引入新的开发契约——新增管道需要在基类注册白名单,新增 stage 需定义 role_affinity 属性,否则会影响分解部署的正确性。

中央配置维护 角色过滤遗漏 dtype 回退精度

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论