Prhub

#24552 [Gemma4] Add test for MTP models

原始 PR 作者 kpham-sgl 合并时间 2026-05-28 12:36 文件变更 3 提交数 8 评论 3 代码增减 +535 / -0

执行摘要

为 Gemma4 MTP 模型添加 GSM8K 端到端测试

本 PR 为 PR #24436 添加测试覆盖,使用 GSM8K 数据集验证 Gemma4 26B-A4B、31B 和 Frozen-KV E4B 三种 MTP 模型在 speculative decoding 下的准确率和接受长度。测试注册为 CI extra-stage,确保 PR 触及 Gemma4 模型路径时自动触发验证。

值得精读测试设计模式,了解 SGLang 中 speculative decoding 端到端测试的编写方法(服务器启动、配置验证、评估执行、断言输出)。可作为新模型 MTP 测试的模板。关注 CI 注册和阶段命名规则,避免后续维护成本。

讨论亮点

在 review 中,开发者对平均接受长度阈值提出了修改意见(原 0.0 过于宽松),作者在 commit 3144ab6 中将其提升至 1.5,并移除了无用的 setUpClass 方法。此外,Frozen-KV 测试增加了 SGLANG_ENABLE_ASYNC_ASSERT 环境变量以在 NaN/Inf/OOB 时快速失败。

实现拆解

  1. 创建测试文件:新建三个 Python 文件置于 test/registered/spec/ 目录,分别对应 26B-A4B、31B 和 Frozen-KV E4B 模型。每个文件定义测试类(继承 CustomTestCase)并通过 register_cuda_ci 注册到 CI 阶段(extra-a 或 base-b)。

  2. 服务器启动辅助:每个测试类实现 _server_env(设置环境变量,如禁用 speculative V2 或启用异步断言)、_common_server_args(固定参数:attention-backend=triton、dtype=bfloat16、context-length=2048 等)、_server_args(根据 topk 动态添加 speculative 参数:算法 NEXTN、draft 模型路径、步数、topk、draft token 数)。

  3. 核心测试流程 _run_gsm8k_mtp:启动 SGLang 服务器 → 调用 /flush_cache 清空缓存 → 查询 /server_info 验证配置(speculative_eagle_topk、disable_cuda_graph)→ 运行 GSM8K 评估(200 样本,5-shot)→ 获取准确率和平均推测接受长度 → 清理进程。

  4. 断言与报告:具体测试方法 test_gsm8k_topk1test_gsm8k_topk3 调用 _run_gsm8k_mtp 并断言准确率超过阈值(26B-A4B: 0.42, 31B: 0.775, E4B: 0.65)且平均接受长度 ≥ 1.5。在 CI 环境下输出 GitHub Step Summary。

  5. CI 适配调整:通过多次提交调整注册参数:从 nightly 迁移到 extra-a、修正 stage 命名(stage-b → base-b)、提高接受长度阈值(0.0 → 1.5)、增加异步断言环境变量,确保与最新 CI 框架兼容。

文件 模块 状态 重要度
test/registered/spec/test_gemma4_mtp_26b_a4b_extra.py 集成测试 added 7.09
test/registered/spec/test_gemma4_mtp_31b_extra.py 集成测试 added 7.09
test/registered/spec/test_frozen_kv_mtp.py 集成测试 added 7.04

关键符号

get_server_info get_avg_spec_accept_length _run_gsm8k_mtp test_gsm8k_topk1 test_gsm8k_topk3

关键源码片段

test/registered/spec/test_frozen_kv_mtp.py test-coverage

Frozen-KV E4B MTP 模型的端到端测试,验证 topk=1 和 topk=3 下的 GSM8K 准确率和接受长度,包含更完整的断言和 CI 报告输出。

class TestFrozenKVMTP(CustomTestCase):
    # 辅助方法省略 ...
​
    def _run_gsm8k_mtp(self, topk: int) -> None:
        """启动服务器并运行 GSM8K 评估,验证准确率和接受长度。"""
        process = None
        try:
            # 使用 popen_launch_server 启动目标模型服务器
            process = popen_launch_server(
                "google/gemma-4-E4B-it", # 目标模型
                self.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 3,
                env=self._server_env(),
                other_args=self._server_args(topk),
            )
            # 清空缓存,确保起始状态
            requests.get(self.base_url + "/flush_cache", timeout=30)
​
            # 获取服务器信息,验证配置正确
            server_info = get_server_info(self.base_url)
            self.assertEqual(
                server_info.get("speculative_eagle_topk"),
                topk,
                f"E4B: server did not start with topk={topk}",
            )
            self.assertFalse(
                bool(server_info.get("disable_cuda_graph")),
                f"E4B/topk{topk}: CUDA graph is disabled",
            )
​
            # 运行 GSM8K 评估
            metrics = run_eval(self._gsm8k_args())
            mtp_score = float(metrics["score"])
            avg_accept = get_avg_spec_accept_length(self.base_url)
        finally:
            if process is not None:
                self._stop_process(process)
​
        # 输出结果到日志和 CI 摘要
        print(
            f"[Frozen-KV MTP E4B topk={topk}] "
            f"score={mtp_score:.4f} threshold={0.65:.4f} "
            f"avg_spec_accept_length={avg_accept}"
        )
        if is_in_ci():
            write_github_step_summary(
                f"### Frozen-KV MTP E4B topk={topk}\n"
                f"score={mtp_score:.4f}\n"
                f"threshold={0.65:.4f}\n"
                f"avg_spec_accept_length={avg_accept}\n"
            )
​
        # 断言准确率和接受长度
        self.assertGreaterEqual(mtp_score, 0.65)
        self.assertIsNotNone(avg_accept)
        self.assertGreaterEqual(
            avg_accept,
            1.5,
            f"E4B/topk{topk}: accept length too low",
        )
​
    def test_gsm8k_topk1(self) -> None:
        """topk=1 时的 MTP 测试。"""
        self._run_gsm8k_mtp(topk=1)
​
    def test_gsm8k_topk3(self) -> None:
        """topk=3 时的 MTP 测试。"""
        self._run_gsm8k_mtp(topk=3)

评论区精华

平均接受长度阈值设定 正确性

Review 指出原阈值 0.0 过于宽松,因为 avg_spec_accept_length 包含 bonus token 应始终 ≥ 1,建议提升到 1.5。

结论:作者将阈值从 0.0 提高到 1.5,并移除了无用的 setUpClass 方法(commit 3144ab6)。 · 已解决

启用异步断言 测试

提交消息指出为 Frozen-KV MTP 测试启用 SGLANG_ENABLE_ASYNC_ASSERT,以在 NaN/Inf/OOB 时快速失败。

结论:在 test_frozen_kv_mtp.py 的环境变量中增加了 SGLANG_ENABLE_ASYNC_ASSERT=1(commit 1c806f2)。 · 已解决

风险与影响

  • 测试依赖 Google 发布的模型权重(google/gemma-4-*)和 assistant 模型,若 HuggingFace 仓库不可用或路径变更,测试将失败。
  • 每个测试耗时长达 720 秒(26B-A4B、31B)或 300 秒(E4B),占用 2 卡或 1 卡 GPU,可能增加 CI 排队时间和资源压力。
  • 准确率阈值基于当前观察值设定,模型微调或数值精度变化可能导致偶发假阳性失败,需要定期校准。
  • 对用户:无直接影响,纯测试变更。
  • 对系统:新增三个 CI 测试,增强对 Gemma4 MTP 功能的回归防护,确保 speculative decoding 质量。
  • 对团队:测试注册在 extra-a 和 base-b 阶段,需通过 run-ci-extra 标签触发,增加 CI 配置复杂度。
依赖外部模型权重 长耗时测试 CI 资源消耗

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论