报告:启用FlashInfer GDN MTP验证于SM100+ (Blackwell)
执行摘要
本PR在SM100+(Blackwell)GPU上启用了FlashInfer GDN的MTP验证,之前该路径抛出 NotImplementedError。通过导入bf16状态MTP kernel并移除配置guard,FlashInfer现在可以成为SM100+上的默认MTP后端。基准测试显示吞吐量和TPOT略优于Triton,准确率(GSM8K、GPQA)达标。变更涉及4个文件,核心改动约15行,配套测试增加了FlashInfer专用评估类。
功能与动机
“Enables FlashInfer GDN MTP (speculative decoding) verify on SM100+ (Blackwell) hardware, previously raising NotImplementedError. SM90 (Hopper) MTP was already supported; this PR completes SM100+ coverage.”
之前SM100+用户无法使用FlashInfer进行MTP验证,必须回退到Triton。完成SM100+覆盖后,用户可以在Blackwell硬件上获得统一且略有提升的推测解码性能。
实现拆解
-
导入bf16 MTP kernel(gdn_flashinfer.py):在 _get_flashinfer_gdn_kernels 中新增从 flashinfer.gdn_kernels.gdn_decode_bf16_state 导入 gated_delta_rule_mtp 并重命名为 gated_delta_rule_mtp_bf16,在返回元组中暴露该函数。
-
添加bf16适配器(gdn_flashinfer.py):新增内部函数 _mtp_bf16_adapted,将FlashInfer bf16状态MTP kernel包装成与现有verify接口兼容的形式(处理 intermediate_states_buffer 的切片和 A_log 的float转换)。target_verify 根据 ssm_states.dtype 选择调用fp32路径或bf16适配路径。
-
移除配置guard(server_args.py):在 _handle_linear_attn_backend 中删除 and self.speculative_algorithm is None 条件,使SM100+在启用任意推测算法(包括MTP)时都能自动默认 linear_attn_decode_backend='flashinfer'。
-
更新后端路由注释(gdn_backend.py):修正 verify_kernel 选择逻辑的注释,指出SM100+现在可以使用FlashInfer进行MTP验证(原来错误地声称不支持)。
-
测试配套(test_qwen35_fp4_mtp.py):
- 将重复的启动参数抽取为
MTP_BASE_ARGS 常量。
- 提取
_run_mtp_gsm8k 工具函数,复用评估逻辑。
- 新增
TestQwen35FP4MTPFlashInfer 类,通过 --linear-attn-decode-backend flashinfer --enforce-disable-flashinfer-allreduce-fusion 启动服务器并运行gsm8k评估。
- 原有
TestQwen35FP4MTP 类保持不变,仍使用Triton后端。
- 测试注册时间从340s调整为740s,以适应两个测试类。
python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
核心变更文件:导入 bf16 状态 MTP kernel,新增 _mtp_bf16_adapted 适配器函数,统一 SM90 和 SM100+ 的 verify 路径。
def _get_flashinfer_gdn_kernels():
"""Lazy import for FlashInfer GDN prefill, decode and verify (MTP) kernels.
Returns (available, prefill_fn, mtp_fn, decode_fn, mtp_bf16_fn).
"""
global _flashinfer_gdn_available, _flashinfer_chunk_gated_delta_rule, _flashinfer_gated_delta_rule_mtp, _flashinfer_gated_delta_rule_decode, _flashinfer_gated_delta_rule_mtp_bf16
if _flashinfer_gdn_available is None:
try:
os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")
from flashinfer.gdn_decode import (
gated_delta_rule_decode_pretranspose,
gated_delta_rule_mtp,
)
from flashinfer.gdn_kernels.gdn_decode_bf16_state import (
gated_delta_rule_mtp as gated_delta_rule_mtp_bf16, # 新增:导入 bf16 状态 MTP kernel
)
from flashinfer.gdn_prefill import chunk_gated_delta_rule
_flashinfer_chunk_gated_delta_rule = chunk_gated_delta_rule
_flashinfer_gated_delta_rule_mtp = gated_delta_rule_mtp
_flashinfer_gated_delta_rule_mtp_bf16 = gated_delta_rule_mtp_bf16 # 新增:保存 bf16 版本函数句柄
_flashinfer_gated_delta_rule_decode = gated_delta_rule_decode_pretranspose
_flashinfer_gdn_available = (
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9
)
if _flashinfer_gdn_available:
logger.info("FlashInfer GDN kernels loaded successfully")
except (ImportError, RuntimeError) as e:
logger.warning(f"FlashInfer GDN kernels not available: {e}")
_flashinfer_gdn_available = False
_flashinfer_gated_delta_rule_decode = None
return (
_flashinfer_gdn_available,
_flashinfer_chunk_gated_delta_rule,
_flashinfer_gated_delta_rule_mtp,
_flashinfer_gated_delta_rule_decode,
_flashinfer_gated_delta_rule_mtp_bf16, # 新增:在返回元组中提供 bf16 版本
)
test/registered/models_e2e/test_qwen35_fp4_mtp.py
测试覆盖:新增 TestQwen35FP4MTPFlashInfer 类验证 FlashInfer 后端下 MTP 的 gsm8k 准确率,同时抽取公共参数和工具函数降低重复。
def _run_mtp_gsm8k(test_case):
"""工具函数:启动 GSM8K 评估并验证准确率与推测接受长度。"""
args = SimpleNamespace(
model=test_case.model,
eval_name="gsm8k",
num_shots=5,
num_examples=200,
max_tokens=16000,
num_threads=128,
repeat=1,
temperature=0.6,
top_p=0.95,
top_k=20,
base_url=test_case.base_url,
host="http://127.0.0.1",
port=int(test_case.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
test_case.assertGreaterEqual(
metrics["score"], ACC_THRESHOLDS[test_case.model]["gsm8k"]
)
server_info = requests.get(test_case.base_url + "/server_info")
avg_spec_accept_length = server_info.json()["internal_states"][0][
"avg_spec_accept_length"
]
print(f"{avg_spec_accept_length=}")
test_case.assertGreater(avg_spec_accept_length, 3.3)
class TestQwen35FP4MTPFlashInfer(ReasoningTokenUsageMixin, CustomTestCase):
"""验证 FlashInfer 后端下的 MTP 推理准确率(GSM8K)。"""
reasoning_parser_name = "qwen3"
@classmethod
def setUpClass(cls):
cls.model = QWEN35_FP4_MODEL
cls.base_url = DEFAULT_URL_FOR_TEST
cls.init_reasoning_token_verifier()
envs.SGLANG_ENABLE_SPEC_V2.set(True)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=MTP_BASE_ARGS
+ [
"--linear-attn-decode-backend",
"flashinfer", # 指定 FlashInfer 后端
"--enforce-disable-flashinfer-allreduce-fusion", # 避免融合引入干扰
],
)
@classmethod
def tearDownClass(cls):
envs.SGLANG_ENABLE_SPEC_V2.set(False)
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
_run_mtp_gsm8k(self)
评论区精华
Fridge003: “Can we add a test for this usage. Maybe under test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py”
YAMY1234: “Added under test/registered/4-gpu-models/test_qwen35_models.py, thanks!”(实际添加在 test_qwen35_fp4_mtp.py 中)
审查者关心新功能的测试覆盖,作者快速响应并添加了专门的FlashInfer MTP测试类。
风险与影响
- 风险:需要 FlashInfer >= 0.6.7;bf16适配路径依赖上游bug修复(flashinfer#3147);测试仅覆盖gsm8k单配置,不覆盖topk>1等场景。
- 影响:SM100+用户无需手动干预即可获得FlashInfer MTP加速;团队需维持两条MTP后端,但代码复用度高;性能提升约1-5%,无显著退化。
关联脉络
该PR完成了GDN MTP在Blackwell上的最后一环,与以下内容关联:
- 上游FlashInfer PR #2810(padding index guard)和 #3147(OOB crash fix)是功能正确性的基础。
- 关联Issue #2679 和 #2810 跟踪了bf16状态MTP kernel的设计与padding修复。
- 同仓库近期PR #26866(Support spec v2 tree drafting)和 #26424(topk=1 fastpath)都属于speculative decoding的持续优化链条。
参与讨论