Prhub

#23273 [NVIDIA] [GDN] Enable FlashInfer MTP verify on SM100+ (Blackwell)

原始 PR 作者 wenscarl 合并时间 2026-06-02 09:56 文件变更 4 提交数 6 评论 13 代码增减 +164 / -87

执行摘要

启用 FlashInfer GDN MTP 验证于 SM100+

Enables FlashInfer GDN MTP (speculative decoding) verify on SM100+ (Blackwell) hardware, previously raising NotImplementedError. SM90 (Hopper) MTP was already supported; this PR completes SM100+ coverage.

建议关注 gdn_flashinfer.py 中 _mtp_bf16_adapted 函数的适配技巧(中间状态切片、A_log float 转换),以及测试文件如何通过抽取公共参数和工具函数降低重复代码。该 PR 设计简洁,适合作为跨硬件后端子类化的参考案例。

讨论亮点

在代码审查中,Fridge003 要求为 FlashInfer MTP 使用场景添加测试,并建议放在 test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py。YAMY1234 回应已在 test_qwen35_fp4_mtp.py 中添加了 FlashInfer 专用测试类 TestQwen35FP4MTPFlashInfer。此外,nvpohanh 指出 H20 CI 失败是已知问题,已由另一个 PR 修复,建议合并。

实现拆解

  1. 在 gdn_flashinfer.py 中新增对 flashinfer.gdn_kernels.gdn_decode_bf16_state 中 gated_delta_rule_mtp 的导入(命名为 gated_delta_rule_mtp_bf16),并在 _get_flashinfer_gdn_kernels 返回元组中补充该函数。
  2. 新增内部函数 _mtp_bf16_adapted,将 FlashInfer bf16 状态 MTP kernel 包装成与现有 verify 接口兼容的形式(处理中间状态张量切片和 A_log 数据类型转换),并在 FlashInferGDNKernel.target_verify 中根据 state dtype 选择调用 fp32 或 bf16 路径。
  3. 在 server_args.py 的 _handle_linear_attn_backend 中移除对 speculative_algorithm is None 的条件判断,使 SM100+ 在启用 MTP 时也能自动默认 FlashInfer 作为线性注意力解码后端。
  4. 在 gdn_backend.py 中更新 verify kernel 选择逻辑的注释,反映 SM100+ 现在可以通过 FlashInfer 进行 MTP 验证(原来被错误地阻止)。
  5. 在 test/registered/models_e2e/test_qwen35_fp4_mtp.py 中将公共启动参数抽取为 MTP_BASE_ARGS 常量,提取 _run_mtp_gsm8k 工具函数,并新增 TestQwen35FP4MTPFlashInfer 测试类,使用 --linear-attn-decode-backend flashinfer 启动服务器并执行 gsm8k 评估,同时保留原有 Triton 测试类。
  6. 延长测试注册预估时间(340s → 740s)以容纳新增的 FlashInfer 测试轮次。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py 注意力内核 modified 7.6
test/registered/models_e2e/test_qwen35_fp4_mtp.py 模型测试 modified 7.28
python/sglang/srt/server_args.py 服务器配置 modified 5.14
python/sglang/srt/layers/attention/linear/gdn_backend.py 后端路由 modified 4.62

关键符号

_mtp_bf16_adapted _run_mtp_gsm8k

关键源码片段

python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py dependency-wiring

核心变更文件:导入 bf16 状态 MTP kernel,新增 _mtp_bf16_adapted 适配器函数,统一 SM90 和 SM100+ 的 verify 路径。

def _get_flashinfer_gdn_kernels():
    """Lazy import for FlashInfer GDN prefill, decode and verify (MTP) kernels.    Returns (available, prefill_fn, mtp_fn, decode_fn, mtp_bf16_fn).
    """
    global _flashinfer_gdn_available, _flashinfer_chunk_gated_delta_rule, _flashinfer_gated_delta_rule_mtp, _flashinfer_gated_delta_rule_decode, _flashinfer_gated_delta_rule_mtp_bf16
    if _flashinfer_gdn_available is None:
        try:
            os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")
​
            from flashinfer.gdn_decode import (
                gated_delta_rule_decode_pretranspose,
                gated_delta_rule_mtp,
            )
            from flashinfer.gdn_kernels.gdn_decode_bf16_state import (
                gated_delta_rule_mtp as gated_delta_rule_mtp_bf16, # 新增:导入 bf16 状态 MTP kernel
            )
            from flashinfer.gdn_prefill import chunk_gated_delta_rule
​
            _flashinfer_chunk_gated_delta_rule = chunk_gated_delta_rule
            _flashinfer_gated_delta_rule_mtp = gated_delta_rule_mtp
            _flashinfer_gated_delta_rule_mtp_bf16 = gated_delta_rule_mtp_bf16 # 新增:保存 bf16 版本函数句柄
            _flashinfer_gated_delta_rule_decode = gated_delta_rule_decode_pretranspose
            _flashinfer_gdn_available = (
                torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9
            )
            if _flashinfer_gdn_available:
                logger.info("FlashInfer GDN kernels loaded successfully")
        except (ImportError, RuntimeError) as e:
            logger.warning(f"FlashInfer GDN kernels not available: {e}")
            _flashinfer_gdn_available = False
            _flashinfer_gated_delta_rule_decode = None
    return (
        _flashinfer_gdn_available,
        _flashinfer_chunk_gated_delta_rule,
        _flashinfer_gated_delta_rule_mtp,
        _flashinfer_gated_delta_rule_decode,
        _flashinfer_gated_delta_rule_mtp_bf16, # 新增:在返回元组中提供 bf16 版本
    )
test/registered/models_e2e/test_qwen35_fp4_mtp.py test-coverage

测试覆盖:新增 TestQwen35FP4MTPFlashInfer 类验证 FlashInfer 后端下 MTP 的 gsm8k 准确率,同时抽取公共参数和工具函数降低重复。

def _run_mtp_gsm8k(test_case):
    """工具函数:启动 GSM8K 评估并验证准确率与推测接受长度。"""
    args = SimpleNamespace(
        model=test_case.model,
        eval_name="gsm8k",
        num_shots=5,
        num_examples=200,
        max_tokens=16000,
        num_threads=128,
        repeat=1,
        temperature=0.6,
        top_p=0.95,
        top_k=20,
        base_url=test_case.base_url,
        host="http://127.0.0.1",
        port=int(test_case.base_url.split(":")[-1]),
    )
    metrics = run_eval(args)
    print(f"{metrics=}")
    test_case.assertGreaterEqual(
        metrics["score"], ACC_THRESHOLDS[test_case.model]["gsm8k"]
    )
​
    server_info = requests.get(test_case.base_url + "/server_info")
    avg_spec_accept_length = server_info.json()["internal_states"][0][
        "avg_spec_accept_length"
    ]
    print(f"{avg_spec_accept_length=}")
    test_case.assertGreater(avg_spec_accept_length, 3.3)
​
​
class TestQwen35FP4MTPFlashInfer(ReasoningTokenUsageMixin, CustomTestCase):
    """验证 FlashInfer 后端下的 MTP 推理准确率(GSM8K)。"""
    reasoning_parser_name = "qwen3"
​
    @classmethod
    def setUpClass(cls):
        cls.model = QWEN35_FP4_MODEL
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.init_reasoning_token_verifier()
        envs.SGLANG_ENABLE_SPEC_V2.set(True)
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=MTP_BASE_ARGS
            + [
                "--linear-attn-decode-backend",
                "flashinfer", # 指定 FlashInfer 后端
                "--enforce-disable-flashinfer-allreduce-fusion", # 避免融合引入干扰
            ],
        )
​
    @classmethod
    def tearDownClass(cls):
        envs.SGLANG_ENABLE_SPEC_V2.set(False)
        kill_process_tree(cls.process.pid)
​
    def test_gsm8k(self):
        _run_mtp_gsm8k(self)

评论区精华

为 FlashInfer MTP 添加测试覆盖 测试

Fridge003 要求添加 FlashInfer MTP 的使用测试,并在 server_args.py 的 diff 上提出。YAMY1234 回应已添加测试类 TestQwen35FP4MTPFlashInfer 在 test_qwen35_fp4_mtp.py 中。

结论:测试已添加,通过添加 TestQwen35FP4MTPFlashInfer 类使用 --linear-attn-decode-backend flashinfer 参数启动服务器并运行 gsm8k 评估。 · 已解决

风险与影响

  1. 依赖更新:需要 FlashInfer >= 0.6.7,否则导入 bf16 状态 kernel 会直接失败。
  2. 新代码路径:_mtp_bf16_adapted 涉及张量重排和 dtype 转换,若 intermediate_states_buffer 形状不匹配可能导致 OOB 写入(上游 flashinfer#3147 已修复)。
  3. 测试覆盖:仅通过 gsm8k(200 样本)和 GPQA 验证,未覆盖 topk>1、不同状态维度或长上下文场景。
  4. 性能退化:基准测试显示 FlashInfer MTP 与 Triton 性能相近(1-5% 优势),无显著退化风险。

用户影响:SM100+ 用户无需手动指定 --linear-attn-decode-backend 即可在 MTP 场景下获得略有提升的性能。系统影响:FlashInfer 成为 SM100+ 且 mamba_ssm_dtype=bf16 时 MTP 解码的默认后端。团队影响:需同时维护 Triton 和 FlashInfer 两条 MTP 验证路径,但核心逻辑高度复用。

依赖 FlashInfer >=0.6.7 新 bf16 适配路径 测试仅覆盖 gsm8k 单配置

关联 Issue

#2679 feat(gdn): add BF16 state kernel with MTP support beyond T>4 with intermediate caching.
#2810 feat(gdn): add padding index guard for bf16 decode kernel
#3145 Fix OOB crash in intermediate_states indexing for GDN decode MTP kernel

完整报告

参与讨论