Prhub

#22604 [Diffusion] Standalone Rollout API + Denoising Environment Backpass + SP-Aligned Log-Prob for T2I Post-Training

原始 PR 作者 Rockdu 合并时间 2026-04-15 10:10 文件变更 19 提交数 55 评论 8 代码增减 +1305 / -76

执行摘要

为扩散模型后训练新增独立 Rollout API,支持轨迹收集和序列并行对齐的对数概率。

根据 PR body,RL-based 后训练(如 FlowGRPO)需要专用 serving endpoint 来隔离通用 T2I 生成路径,并提供足够的轨迹元数据以在外部重放 rollout 进行策略梯度计算,同时在 Sequence Parallelism 下保持数值稳定性。

建议精读此 PR 以学习其设计模式:混入类(RolloutDenoisingMixin)分离核心逻辑、SP 对齐策略(避免 all_reduce)和按样本粒度 API 设计。关注 _kwargs_to_cpu 的递归问题和文件组织,可能需后续优化。

讨论亮点

review 中,gemini-code-assist[bot] 指出 _kwargs_to_cpu 函数不是完全递归的,对于嵌套序列(如列表的列表中的张量)无法正确处理,且可能将元组转换为列表,建议实现一个更健壮的递归版本以保持类型一致性。mickqian 评论了文件路径问题,建议将 mixin 文件移至 pipeline_configs/mixins 目录。这些讨论凸显了代码健壮性和项目结构的一致性问题,但 PR 已合并,可能未完全解决。

实现拆解

  1. 新增独立 Rollout HTTP API:在 python/sglang/multimodal_gen/runtime/entrypoints/post_training/rollout_api.py 中定义 POST /rollout/generate 端点,使用 RolloutRequest/RolloutResponse 结构,通过 _extract_single_sample_tensor_slice_rollout_trajectory_for_sample 实现按样本粒度输出,便于 RL 训练器处理单个轨迹。
  2. 引入 RolloutDenoisingMixin:在 python/sglang/multimodal_gen/runtime/post_training/rollout_denoising_mixin.py 中创建混入类,添加 _maybe_init_denoising_env_collection_maybe_append_dit_trajectory_step 等方法,用于收集冻结的 transformer kwargs 和 DiT 轨迹(如原始噪声潜变量),通过 _kwargs_to_cpu 将张量移至 CPU。
  3. 增强序列并行支持:在 python/sglang/multimodal_gen/runtime/post_training/sp_utils.py 中新增 gather_stacked_latents_for_sp 等函数,确保 SP 下张量被正确收集到完整形状,避免数值漂移;SP 对齐对数概率通过全缓冲区噪声计算实现,无需额外集合通信。
  4. 添加张量序列化工具:在 python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py 中定义 tensor_to_base64_maybe_serialize 函数,用于将轨迹数据序列化为 base64 字符串以便 HTTP 传输。
  5. 配套测试与配置:新增 python/sglang/multimodal_gen/test/unit/test_rollout_api.py 和修改 test_scheduler_rollout_unit.py 等测试文件,验证序列化、API 响应和数值正确性;同时添加 pipeline config mixins 如 qwen_image_rollout_pipeline_mixin.py 用于模型特定收集逻辑。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/entrypoints/post_training/rollout_api.py 后训练入口 added 9.17
python/sglang/multimodal_gen/runtime/post_training/rollout_denoising_mixin.py 去噪混入 added 8.81
python/sglang/multimodal_gen/runtime/post_training/sp_utils.py 序列并行工具 added 8.12
python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py 序列化工具 added 8.07
python/sglang/multimodal_gen/test/unit/test_rollout_api.py 单元测试 added 7.74

关键符号

_extract_single_sample_tensor rollout_generate _kwargs_to_cpu _maybe_prepare_rollout _maybe_collect_rollout_log_probs gather_stacked_latents_for_sp tensor_to_base64 _maybe_serialize

关键源码片段

python/sglang/multimodal_gen/runtime/entrypoints/post_training/rollout_api.py entrypoint

定义了独立的 Rollout HTTP API 端点,是实现用户接口的核心文件。

def _extract_single_sample_tensor(obj: Any, sample_idx: int, batch_size: int) -> Any:
    """递归提取单个样本的张量,用于将批次数据拆分为 per-sample 粒度。"""
    if isinstance(obj, torch.Tensor):
        # 如果张量的第一维等于批次大小,则提取对应样本
        if obj.dim() >= 1 and obj.shape[0] == batch_size:
            return obj[sample_idx].contiguous()
        return obj
    if isinstance(obj, dict):
        # 递归处理字典中的每个值
        return {
            k: _extract_single_sample_tensor(v, sample_idx, batch_size)
            for k, v in obj.items()
        }
    if isinstance(obj, list):
        # 递归处理列表
        return [_extract_single_sample_tensor(v, sample_idx, batch_size) for v in obj]
    if isinstance(obj, tuple):
        # 递归处理元组并保持类型
        return tuple(
            _extract_single_sample_tensor(v, sample_idx, batch_size) for v in obj
        )
    return obj # 非张量或容器类型直接返回
python/sglang/multimodal_gen/runtime/post_training/rollout_denoising_mixin.py dependency-wiring

包含 RolloutDenoisingMixin 混入类,负责收集 Denoising 环境数据和 DiT 轨迹。

def _kwargs_to_cpu(d: Any) -> Any:
    """将嵌套结构中的张量移至 CPU,但当前实现未完全递归处理嵌套序列。"""
    if isinstance(d, torch.Tensor):
        return d.detach().cpu()
    if isinstance(d, dict):
        return {k: _kwargs_to_cpu(v) for k, v in d.items()}
    if isinstance(d, list):
        return [_kwargs_to_cpu(v) for v in d]
    if isinstance(d, tuple):
        return tuple(_kwargs_to_cpu(v) for v in d) # 注意:这里可能未处理嵌套元组中的张量
    return dclass RolloutDenoisingMixin:
    def _maybe_init_denoising_env_collection(
        self,
        batch,
        pipeline_config,
        image_kwargs: dict[str, Any],
        pos_cond_kwargs: dict[str, Any],
        neg_cond_kwargs: dict[str, Any],
        guidance: torch.Tensor | None,
    ) -> None:
        """根据标志初始化 Denoising 环境收集状态。"""
        collect_env = batch.rollout_return_denoising_env
        collect_traj = batch.rollout_return_dit_trajectory
        if not (collect_env or collect_traj):
            batch._rollout_dit_env_state = None
            return
        # 使用 pipeline_config 的清理函数处理 kwargs
        sanitize = getattr(pipeline_config, "sanitize_dit_env_kwargs", lambda x: x)
        if collect_env:
            env = RolloutDenoisingEnv(
                image_kwargs=_kwargs_to_cpu(sanitize(image_kwargs)),
                pos_cond_kwargs=_kwargs_to_cpu(sanitize(pos_cond_kwargs)),
                neg_cond_kwargs=(
                    _kwargs_to_cpu(sanitize(neg_cond_kwargs))
                    if neg_cond_kwargs
                    else None
                ),
                guidance=guidance.detach().cpu() if guidance is not None else None,
            )
        else:
            env = None
        # 初始化状态以存储轨迹步骤
        batch._rollout_dit_env_state = {
            "env": env,
            "step_latents": [],
            "timesteps": [],
        }

评论区精华

_kwargs_to_cpu 函数的递归问题 正确性

gemini-code-assist[bot] 指出函数未完全递归处理嵌套序列(如列表的列表中的张量),且可能将元组转换为列表,影响下游代码。

结论:建议实现更健壮的递归版本以保持类型一致性,但 PR 中未显示是否已修复。 · 待处理

文件路径组织 style

mickqian 评论建议将 mixin 文件移至 pipeline_configs/mixins 目录以保持项目结构一致。

结论:评论简短,可能已接受或忽略,但未在 PR 中明确调整。 · addressed

风险与影响

技术风险包括:

  1. 回归风险:新 API 和混入类可能影响现有扩散模型生成路径,特别是在 DenoisingStage 的修改中(python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py),若生命周期钩子调用顺序错误可能导致逻辑错误。
  2. 性能开销:张量序列化和反序列化(通过 tensor_to_base64)可能增加 CPU 和内存开销,尤其在返回大型轨迹时。
  3. 数值稳定性:SP 对齐对数概率依赖于全缓冲区噪声计算,若噪声生成或收集逻辑有误,可能导致跨 rank 不一致。
  4. 兼容性:新 API 仅支持 T2I 任务,未来扩展至其他任务(如 T2V)时需额外适配。

对用户(RL 训练者)而言,新增了专用端点,简化了轨迹数据获取,提升训练效率;对系统,增加了新的 HTTP 路由和数据处理模块,需要维护和测试;对团队,引入了混入类和 SP 工具,需理解设计以进行后续开发。影响范围集中在扩散模型后训练模块,不影响核心推理路径。

新 API 接口 序列化开销 SP 逻辑复杂 生命周期钩子风险

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论