Prhub

#26287 [RL] Fix FP8 skip matching for trailing-dot prefixes

原始 PR 作者 zianglih 合并时间 2026-05-27 04:30 文件变更 3 提交数 2 评论 4 代码增减 +63 / -0

执行摘要

修复 trailing-dot 前缀的 FP8 skip 匹配

Mixed MXFP8/BF16 检查点使用尾随点前缀(如 model.layers.34.)来保持层内所有子模块为 BF16 高精度。PR #23467 切换到 dot-boundary 匹配后,这样的前缀无法匹配子模块(如 model.layers.34.mlp.experts.0.down_proj),导致本应保持 BF16 的层被量化,造成模型精度回归。

建议精读。本 PR 展示了如何在保持向后兼容的前提下修复 dot-boundary 匹配导致的回归,其测试策略值得参考:单元测试聚焦边界条件,集成测试使用真实模型和评估管道验证精度。

讨论亮点

无实质 review 讨论。gemini-code-assist[bot] 的自动评论未提出具体问题,b8zhong 直接批准。

实现拆解

  1. 核心逻辑修复python/sglang/srt/layers/quantization/utils.py):在 _module_path_match 函数开头添加 ignored = ignored.rstrip(".")prefix = prefix.rstrip("."),归一化尾随点,确保后续的 dot-boundary 匹配逻辑能正确处理尾随点前缀。
  2. 单元测试增强test/registered/quant/test_is_layer_skipped.py):新增 test_trailing_dot_prefix_matches_child_modules 方法,验证 model.layers.34. 可以匹配子模块 model.layers.34.mlp.experts.0.down_proj,同时确保 model.layers.340. 不会误匹配。
  3. 端到端集成测试test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py):新增 FlashinferTrtllmGenMoeBackendMXFP8MixedBF16Base 基类和 TestFlashinferTrtllmRoutedMxfp8MixedBF16 测试类,使用真实的混合精度模型 zianglih/JoyAI-LLM-Flash-MXFP8-last-6-BF16,在 flashinfer_trtllm_routed 后端下以 tp=4 运行 GSM8K 评测,断言得分 > 0.92。
文件 模块 状态 重要度
python/sglang/srt/layers/quantization/utils.py 核心逻辑 modified 5.69
test/registered/quant/test_is_layer_skipped.py 测试 modified 5.78
test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py 测试 modified 6.97

关键符号

_module_path_match

关键源码片段

python/sglang/srt/layers/quantization/utils.py core-logic

核心修复:在 _module_path_match 中添加 rstrip 归一化尾随点,直接修复匹配逻辑

# python/sglang/srt/layers/quantization/utils.pydef _module_path_match(ignored: str, prefix: str) -> bool:
    # Match on dotted module-path boundaries so that `mlp.gate` does NOT
    # match `mlp.gate_up_proj`. Needed for quant configs (e.g. Qwen3.6-FP8)
    # whose `modules_to_not_convert` lists MoE-template names like `mlp.gate`
    # that collide with fused dense MLP names by plain substring.
    # 修复:去除尾随点,使 `model.layers.34.` 能匹配子模块
    # 此前由于 dot-boundary 匹配,`model.layers.34.` 不会匹配
    # `model.layers.34.mlp.experts.0.down_proj`,导致应保持 BF16 的层被量化
    ignored = ignored.rstrip(".")
    prefix = prefix.rstrip(".")
    if ignored == prefix:
        return True
    if prefix.startswith(ignored + "."):
        return True
    return ("." + ignored + ".") in ("." + prefix + ".")
test/registered/quant/test_is_layer_skipped.py test-coverage

单元测试验证尾随点前缀的匹配行为,包括正确匹配和拒绝误匹配

# test/registered/quant/test_is_layer_skipped.pyclass TestIsLayerSkipped(CustomTestCase):
    # ... 已有测试方法 ...
​
    def test_trailing_dot_prefix_matches_child_modules(self):
        # Mixed-precision checkpoints may use a trailing-dot layer prefix to keep
        # every module under the layer in higher precision.
        ignored = ["model.layers.34."]
        # model.layers.34. 应匹配 model.layers.34.mlp.experts.0.down_proj
        self.assertTrue(
            is_layer_skipped("model.layers.34.mlp.experts.0.down_proj", ignored, {})
        )
        # 但不应误匹配 model.layers.340(注意数字不同)
        self.assertFalse(
            is_layer_skipped("model.layers.340.mlp.experts.0.down_proj", ignored, {})
        )
test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py test-coverage

端到端集成测试:使用真实混合精度模型验证修复后的精度达标

# test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.pyclass FlashinferTrtllmGenMoeBackendMXFP8MixedBF16Base:
    backend = None
​
    @classmethod
    def setUpClass(cls):
        # 使用真实的混合精度模型:最后 6 层保持 BF16,其余为 MXFP8
        cls.model = "zianglih/JoyAI-LLM-Flash-MXFP8-last-6-BF16"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"},
            other_args=[
                "--kv-cache-dtype", "bf16",
                "--fp8-gemm-backend", "flashinfer_cutlass",
                "--moe-runner-backend", cls.backend,
                "--tp-size", "4",
                "--trust-remote-code",
            ],
        )
​
    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
​
    def test_gsm8k(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="gsm8k",
            api="completion",
            max_tokens=512,
            num_examples=200,
            num_threads=128,
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        # 修复后精度应达到 0.92 以上
        self.assertGreater(metrics["score"], 0.92)
​
​
class TestFlashinferTrtllmRoutedMxfp8MixedBF16(
    FlashinferTrtllmGenMoeBackendMXFP8MixedBF16Base, CustomTestCase
):
    backend = "flashinfer_trtllm_routed"

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

变更本身风险低:仅添加 2 行代码(rstrip),且已有单元测试验证边界情况。潜在风险是 rstrip 可能去除用户有意保留的尾随点,但合理的使用场景中尾随点仅用于表示前缀不完整,去除后不影响匹配语义。

直接影响:修复了使用尾随点前缀的 Mixed MXFP8/BF16 检查点的精度问题,用户无需修改配置即可正确跳过指定层。间接影响:_module_path_match 被多个量化路径使用,此修改不会破坏现有行为(因为之前尾随点前缀实际失效,修复后相当于扩展了匹配能力)。

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论