Prhub

#21204 [Diffusion] Revamp Rollout Log-Prob Support with SDE/CPS for RL Post-Training

sgl-project/sglang · 作者 Rockdu · 合并时间 2026-04-09 09:00

分析状态 已生成
文件变更 17提交数 24 · 评论 33
代码增减 +944 / -11
diffusion run-ci feature test scheduling

执行摘要

为扩散模型 RL 后训练新增模块化 Rollout Log-Prob 引擎,支持 SDE/CPS/ODE 策略。

根据PR body,RL-based post-training of diffusion models (e.g., FlowGRPO) requires computing per-step log-probabilities along the denoising trajectory。因此,需要为扩散模型添加计算每一步log-probabilities的能力,以支持如FlowGRPO等强化学习后训练算法。

建议技术管理者和扩散模型开发者精读此PR,关注其模块化设计、混合模式集成以及序列并行兼容性的实现细节,为类似功能扩展提供参考。

讨论亮点

Review中主要讨论点包括:1) gemini-code-assist[bot]指出在scheduler_rl_mixin.py中潜在除零错误,当rollout_noise_level为0且log_prob_no_const为False时;2) mickqian建议为RL参数使用专用解析器以提升代码清晰度,Rockdu表示同意;3) 关于调度器接口修改,Rockdu讨论将batch参数吸收到mixin中以避免侵入式改动;4) 单元测试的注册和清理问题也被提及。

实现拆解

实现方案包括:1) 新增post_training目录,包含rl_dataclasses.py定义数据结构、scheduler_rl_mixin.py实现核心log-prob引擎、scheduler_rl_debug_mixin.py处理调试张量;2) 在SamplingParams中添加rollout相关参数,并集成验证逻辑;3) 修改FlowMatchEulerDiscreteScheduler,通过混合模式集成Rollout功能;4) 更新denoising pipeline,添加_prepare_rollout和_collect_rollout_log_probs生命周期钩子;5) 添加单元测试验证ODE、SDE、CPS模式的对齐和正确性。

文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_mixin.py diffusion/rl added 9.0
python/sglang/multimodal_gen/configs/post_training/rl_rollout.py config added 7.0
python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py scheduler modified 8.0
python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py pipeline modified 7.0
python/sglang/multimodal_gen/test/unit/test_scheduler_rollout_unit.py test added 6.0

分析完成后,这里会展示 LLM 生成的相对完整源码片段和详细注释。

关键符号

SchedulerRLMixin.flow_sde_sampling SchedulerRLMixin.prepare_rollout SchedulerRLMixin.collect_rollout_log_probs RLRolloutArgs.validate DenoisingStage._maybe_prepare_rollout

评论区精华

除法零错误风险 正确性

gemini-code-assist[bot] 指出在 scheduler_rl_mixin.py 中,当 rollout_noise_level 为 0 且 log_prob_no_const 为 False 时,noise_std_dev 为零,导致除零和 log(-inf)

结论:未在 PR 中解决,建议添加验证防止零噪声级别 · 未解决

参数解析器设计 设计

mickqian 建议为 RL 参数使用专用解析器,以提升代码模块性和清晰度

结论:Rockdu 同意,但作为后续优化,当前 PR 未实现 · 已讨论但未实现

调度器接口修改 设计

Rockdu 讨论将 batch 参数吸收到 SchedulerRLMixin 中,以避免对调度器 step 接口的侵入式改动

结论:最终修改了 step 接口添加 batch 参数,但通过 mixin 管理状态 · 已实现

单元测试注册 测试

mickqian 指出测试文件注册问题,Rockdu 随后修复

结论:已修复,移除重复注册 · 已解决

风险与影响

技术风险包括:1) 除零错误:当rollout_noise_level设置为0且rollout_log_prob_no_const为False时,log-prob计算可能失败,需添加验证;2) 性能影响:新增log-prob计算可能增加计算开销,尤其是在调试模式下收集张量;3) 兼容性:修改了调度器step接口,添加batch参数,可能影响现有代码调用;4) 序列并行集成:需要确保在分布式环境下log-prob的归并正确,测试覆盖需充分。

对用户影响:为扩散模型RL后训练提供了标准化的log-prob计算API,支持多种策略,便于集成强化学习算法;对系统影响:增加了代码复杂性和维护点,但通过模块化设计最小化侵入,确保无rollout时无额外开销;对团队影响:引入了混合模式和模块化设计,为后续功能扩展提供参考,但需团队成员熟悉新架构。

潜在除零错误 接口变更风险 性能开销 序列并行兼容性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

执行摘要

  • 一句话:为扩散模型RL后训练新增模块化Rollout Log-Prob引擎,支持SDE/CPS/ODE策略。
  • 推荐动作:建议技术管理者和扩散模型开发者精读此PR,关注其模块化设计、混合模式集成以及序列并行兼容性的实现细节,为类似功能扩展提供参考。

功能与动机

根据PR body,RL-based post-training of diffusion models (e.g., FlowGRPO) requires computing per-step log-probabilities along the denoising trajectory。因此,需要为扩散模型添加计算每一步log-probabilities的能力,以支持如FlowGRPO等强化学习后训练算法。

实现拆解

实现方案包括:1) 新增post_training目录,包含rl_dataclasses.py定义数据结构、scheduler_rl_mixin.py实现核心log-prob引擎、scheduler_rl_debug_mixin.py处理调试张量;2) 在SamplingParams中添加rollout相关参数,并集成验证逻辑;3) 修改FlowMatchEulerDiscreteScheduler,通过混合模式集成Rollout功能;4) 更新denoising pipeline,添加_prepare_rollout和_collect_rollout_log_probs生命周期钩子;5) 添加单元测试验证ODE、SDE、CPS模式的对齐和正确性。

关键文件:

  • python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_mixin.py(模块 diffusion/rl): 核心Rollout Log-Prob引擎,实现SDE/CPS/ODE采样和log-prob计算逻辑
  • python/sglang/multimodal_gen/configs/post_training/rl_rollout.py(模块 config): 定义RL Rollout参数配置、验证和CLI接口
  • python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py(模块 scheduler): 集成Rollout到现有调度器step方法,确保向后兼容
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py(模块 pipeline): 在降噪循环中添加Rollout生命周期管理钩子
  • python/sglang/multimodal_gen/test/unit/test_scheduler_rollout_unit.py(模块 test): 单元测试验证Rollout引擎的正确性和对齐

关键符号:SchedulerRLMixin.flow_sde_sampling, SchedulerRLMixin.prepare_rollout, SchedulerRLMixin.collect_rollout_log_probs, RLRolloutArgs.validate, DenoisingStage._maybe_prepare_rollout

评论区精华

Review中主要讨论点包括:1) gemini-code-assist[bot]指出在scheduler_rl_mixin.py中潜在除零错误,当rollout_noise_level为0且log_prob_no_const为False时;2) mickqian建议为RL参数使用专用解析器以提升代码清晰度,Rockdu表示同意;3) 关于调度器接口修改,Rockdu讨论将batch参数吸收到mixin中以避免侵入式改动;4) 单元测试的注册和清理问题也被提及。

  • 除法零错误风险 (correctness): 未在PR中解决,建议添加验证防止零噪声级别
  • 参数解析器设计 (design): Rockdu同意,但作为后续优化,当前PR未实现
  • 调度器接口修改 (design): 最终修改了step接口添加batch参数,但通过mixin管理状态
  • 单元测试注册 (testing): 已修复,移除重复注册

风险与影响

  • 风险:技术风险包括:1) 除零错误:当rollout_noise_level设置为0且rollout_log_prob_no_const为False时,log-prob计算可能失败,需添加验证;2) 性能影响:新增log-prob计算可能增加计算开销,尤其是在调试模式下收集张量;3) 兼容性:修改了调度器step接口,添加batch参数,可能影响现有代码调用;4) 序列并行集成:需要确保在分布式环境下log-prob的归并正确,测试覆盖需充分。
  • 影响:对用户影响:为扩散模型RL后训练提供了标准化的log-prob计算API,支持多种策略,便于集成强化学习算法;对系统影响:增加了代码复杂性和维护点,但通过模块化设计最小化侵入,确保无rollout时无额外开销;对团队影响:引入了混合模式和模块化设计,为后续功能扩展提供参考,但需团队成员熟悉新架构。
  • 风险标记:潜在除零错误, 接口变更风险, 性能开销, 序列并行兼容性

关联脉络

  • 暂无明显关联 PR

参与讨论