Prhub

#26530 [diffusion] CI: Infer diffusion test sampling params from task type

原始 PR 作者 mickqian 合并时间 2026-06-01 11:57 文件变更 4 提交数 7 评论 2 代码增减 +98 / -66

执行摘要

扩散测试采样参数自动推断,移除冗余参数

简化 diffusion 测试用例的编写和维护,避免因人工指定错误的 sampling_params 模板(如 LTX 之前误用 T2I 模板)导致测试参数错误。PR body 明确指出:允许 test case 省略 sampling_params,从 ModelTaskType 推断共享的测试采样模板。

建议扩散测试相关开发者精读 testcase_configs.py 中的新增函数,理解推断逻辑。该 PR 的设计模式(通过 task type 自动选择参数模板)值得在其他类似需要多配置的测试场景中复制。同时注意后续新增模型时检查推断映射是否覆盖。

讨论亮点

无 Review 评论。从提交历史可见渐进式演进:先实现核心推断逻辑,然后移除冗余参数,接着添加 fallback 到 model task 以支持自定义 pipeline case,最后调整 LTX baseline。整体开发思路清晰,未出现重大争议。

实现拆解

  1. sampling_params 可选化:在 DiffusionTestCase 中将 sampling_params 字段类型改为 DiffusionSamplingParams | None,并在 __post_init__ 中若为 None 则调用 get_default_sampling_params_for_server_args 自动填充。
  2. 新增推断函数
    • _get_extra_arg_value(extras, option_name):解析 server_args.extras 中的键值对。
    • get_model_task_type_for_server_args(server_args): 先检查 --pipeline-class-name,若存在则通过 get_pipeline_config_classes 获取 task type;否则 fallback 到 model registry 中的 pipeline_config_cls.task_type。
    • get_default_sampling_params_for_model_task(task_type): 根据 task type(T2I, I2I/TI2I, T2V, I2V/TI2V, I2M 等)返回对应预定义的采样参数。
    • get_default_sampling_params_for_server_args(server_args): 组合上述两步。
  3. 清理 gpu_cases.py:移除约 35 行显式传递的 T2I_sampling_paramsTI2I_sampling_paramsT2V_sampling_params 等参数,仅保留需要覆盖默认值的 case 显式传递(如 qwen_image_t2i_cache_dit_scm_config_diffusers_1gpu 通过 replace 修改参数)。同时移除不再需要的导入符号。
  4. 更新 perf baseline:根据新的 LTX 2.3 HQ 采样参数重新运行测试,更新 perf_baselines.json 中的 ltx_2_3_hq_pipeline 性能基线数据(stages_ms、denoise_step_ms、expected_e2e_ms 等均有变化)。
  5. 锁定 ci-data 版本:在 test_utils.py 中将 SGL_TEST_FILES_CI_DATA_REVISION 更新为对应 LTX 2.3 HQ 新 GT 的提交 SHA。
    无额外配置或部署改动。
文件 模块 状态 重要度
python/sglang/multimodal_gen/test/server/testcase_configs.py 测试配置 modified 7.19
python/sglang/multimodal_gen/test/server/perf_baselines.json 性能基线 modified 5.4
python/sglang/multimodal_gen/test/server/gpu_cases.py GPU 用例 modified 4.86
python/sglang/multimodal_gen/test/test_utils.py 测试工具 modified 3.7

关键符号

_get_extra_arg_value get_model_task_type_for_server_args get_default_sampling_params_for_model_task get_default_sampling_params_for_server_args

关键源码片段

python/sglang/multimodal_gen/test/server/testcase_configs.py test-coverage

核心变更文件:添加了 sampling_params 可选化和自动推断逻辑

# 根据 server_args 推断任务类型
def get_model_task_type_for_server_args(server_args: DiffusionServerArgs) -> ModelTaskType:
    # 首先尝试从 extras 中解析 --pipeline-class-name
    pipeline_class_name = _get_extra_arg_value(
        server_args.extras, "--pipeline-class-name"
    )
    if pipeline_class_name:
        config_classes = get_pipeline_config_classes(pipeline_class_name)
        if config_classes is not None:
            pipeline_config_cls, _ = config_classes
            return pipeline_config_cls.task_type
    # 回退到从 model registry 获取模型信息
    model_info = get_model_info(server_args.model_path)
    if model_info is None:
        raise ValueError(
            f"Could not resolve model info for {server_args.model_path!r}"
        )
    return model_info.pipeline_config_cls.task_type
​
​
# 根据任务类型返回默认采样参数
def get_default_sampling_params_for_model_task(
    task_type: ModelTaskType,
) -> DiffusionSamplingParams:
    if task_type == ModelTaskType.T2I:
        return T2I_sampling_params
    if task_type in (ModelTaskType.I2I, ModelTaskType.TI2I):
        return TI2I_sampling_params
    if task_type == ModelTaskType.T2V:
        return T2V_sampling_params
    if task_type in (ModelTaskType.I2V, ModelTaskType.TI2V):
        return TI2V_sampling_params
    if task_type == ModelTaskType.I2M:
        return HUNYUAN3D_SHAPE_sampling_params
    # 默认使用 T2I 参数
    return T2I_sampling_params

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  1. 推断映射不完整风险get_default_sampling_params_for_model_task 中定义的 task type 映射可能未覆盖所有模型变体或新增任务类型,导致新模型测试时获得不正确的默认参数。
  2. 自定义 pipeline 解析风险get_model_task_type_for_server_args 先检查 --pipeline-class-name,但若 get_pipeline_config_classes 返回 None 则会 fallback 到 model registry,此时可能得到错误的 task type。
  3. 基线更新风险:新基线依赖于 ci-data 更新,若 ci-data 未被正确拉取,测试可能因基线不匹配而失败。
  4. 回归风险:移除了大量显式参数,但覆盖了 DiffusionTestCase 构造函数的默认行为,若某些 case 依赖的默认参数与之前显式传递的不一致,可能导致测试失败。从 source_diff 来看,所有移除的 case 的 sampling_params 都正好是推断出的默认值,因此风险较低。

影响范围:仅限于 diffusion 测试模块(python/sglang/multimodal_gen/test/server/),影响开发者添加和运行测试的方式。开发者现在只需指定 server_args 即可自动获得正确的采样参数,减少样板代码和潜在错误。对用户和系统的运行无影响。对团队协作的影响:标准化了测试参数配置,降低了新模型接入测试的门槛。

推断映射可能遗漏模型类型 基线依赖外部 ci-data

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论