执行摘要
- 一句话:为扩散模型RL后训练新增模块化Rollout Log-Prob引擎,支持SDE/CPS/ODE策略。
- 推荐动作:建议技术管理者和扩散模型开发者精读此PR,关注其模块化设计、混合模式集成以及序列并行兼容性的实现细节,为类似功能扩展提供参考。
功能与动机
根据PR body,RL-based post-training of diffusion models (e.g., FlowGRPO) requires computing per-step log-probabilities along the denoising trajectory。因此,需要为扩散模型添加计算每一步log-probabilities的能力,以支持如FlowGRPO等强化学习后训练算法。
实现拆解
实现方案包括:1) 新增post_training目录,包含rl_dataclasses.py定义数据结构、scheduler_rl_mixin.py实现核心log-prob引擎、scheduler_rl_debug_mixin.py处理调试张量;2) 在SamplingParams中添加rollout相关参数,并集成验证逻辑;3) 修改FlowMatchEulerDiscreteScheduler,通过混合模式集成Rollout功能;4) 更新denoising pipeline,添加_prepare_rollout和_collect_rollout_log_probs生命周期钩子;5) 添加单元测试验证ODE、SDE、CPS模式的对齐和正确性。
关键文件:
python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_mixin.py(模块 diffusion/rl): 核心Rollout Log-Prob引擎,实现SDE/CPS/ODE采样和log-prob计算逻辑
python/sglang/multimodal_gen/configs/post_training/rl_rollout.py(模块 config): 定义RL Rollout参数配置、验证和CLI接口
python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py(模块 scheduler): 集成Rollout到现有调度器step方法,确保向后兼容
python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py(模块 pipeline): 在降噪循环中添加Rollout生命周期管理钩子
python/sglang/multimodal_gen/test/unit/test_scheduler_rollout_unit.py(模块 test): 单元测试验证Rollout引擎的正确性和对齐
关键符号:SchedulerRLMixin.flow_sde_sampling, SchedulerRLMixin.prepare_rollout, SchedulerRLMixin.collect_rollout_log_probs, RLRolloutArgs.validate, DenoisingStage._maybe_prepare_rollout
评论区精华
Review中主要讨论点包括:1) gemini-code-assist[bot]指出在scheduler_rl_mixin.py中潜在除零错误,当rollout_noise_level为0且log_prob_no_const为False时;2) mickqian建议为RL参数使用专用解析器以提升代码清晰度,Rockdu表示同意;3) 关于调度器接口修改,Rockdu讨论将batch参数吸收到mixin中以避免侵入式改动;4) 单元测试的注册和清理问题也被提及。
- 除法零错误风险 (correctness): 未在PR中解决,建议添加验证防止零噪声级别
- 参数解析器设计 (design): Rockdu同意,但作为后续优化,当前PR未实现
- 调度器接口修改 (design): 最终修改了step接口添加batch参数,但通过mixin管理状态
- 单元测试注册 (testing): 已修复,移除重复注册
风险与影响
- 风险:技术风险包括:1) 除零错误:当rollout_noise_level设置为0且rollout_log_prob_no_const为False时,log-prob计算可能失败,需添加验证;2) 性能影响:新增log-prob计算可能增加计算开销,尤其是在调试模式下收集张量;3) 兼容性:修改了调度器step接口,添加batch参数,可能影响现有代码调用;4) 序列并行集成:需要确保在分布式环境下log-prob的归并正确,测试覆盖需充分。
- 影响:对用户影响:为扩散模型RL后训练提供了标准化的log-prob计算API,支持多种策略,便于集成强化学习算法;对系统影响:增加了代码复杂性和维护点,但通过模块化设计最小化侵入,确保无rollout时无额外开销;对团队影响:引入了混合模式和模块化设计,为后续功能扩展提供参考,但需团队成员熟悉新架构。
- 风险标记:潜在除零错误, 接口变更风险, 性能开销, 序列并行兼容性
关联脉络
参与讨论