Prhub

#22921 [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell)

原始 PR 作者 kaixih 合并时间 2026-05-28 04:58 文件变更 4 提交数 4 评论 30 代码增减 +146 / -41

执行摘要

Blackwell SM100+ 添加 FlashInfer GDN prefill 支持

PR body 明确说明:Extends FlashInfer GDN kernel support to cover the prefill/extend path on SM100+ (Blackwell) hardware, previously raising NotImplementedError. 目的是完成 SM100+ 上的功能覆盖,从而在 Blackwell 上也能利用 FlashInfer 的高效 prefill 实现,避免回退到 Triton 后端以提升性能。

该 PR 是 Blackwell 推理栈的重要补齐,设计决策清晰(状态预分配、clamp 保护、版本校验)。值得关注:

  • SM100 / SM90 两条路径的差异(state pool vs gather/scatter)及初始化分支逻辑;
  • 如何通过预分配 bf16 output_state 消除类型转换开销;
  • 对上游 FlashInfer 版本的依赖管理。
    推荐阅读核心内核文件 gdn_flashinfer.pyextend 方法,以理解 FlashInfer 集成模式。
讨论亮点

Review 焦点:

  • 负数 padding 索引来源(hlu1 提问):cache_indices 中的 -1 含义是什么?kaixih 回复已在注释中澄清,-1 是未分配序列的填充标记,使用 clamp 将其映射到 0 号 dummy slot 以避免越界。

  • 小核融合建议(hlu1 建议使用 torch.compile 或 triton 融合前处理):kaixih 以数据流图详细回复,说明 q/k 各自经 l2norm 后喂入 FlashInfer,来自不同源且 l2norm 是自定义 Triton 核,不适合融合。hlu1 后续建议 decode 路径已经使用了并行 CUDA stream,prefill 路径因 token 数大而暂不适用。

  • l2norm strided 输入(hlu1 建议修改 l2norm_fwd 以支持 strided 输入避免 contiguous 调用):这是一个待办优化,未在此 PR 中实现。

  • bf16 state dtype 校验(yuan-luo 建议为 SM100+ prefill 增加 bf16 状态 dtype 验证,与 decode 路径对齐):kaixih 解释 FlashInfer prefill 内核本身支持 fp32/bf16 两种状态 dtype,而 decode 仅支持 bf16,因此 prefill 路径不需要额外限制。最终未添加该校验。

实现拆解

  1. 调整内核可用性判断:在 gdn_flashinfer.pyFlashInferGDNKernel.__init__ 中,将 self.use_state_pool = sm_major != 9 改为 self.use_state_pool = sm_major >= 10,使得 SM100+ 使用 state pool API,SM90 维持原有 gather/scatter 路径。同时更新了类文档字符串。

  2. 实现 SM100+ prefill 路径:在 extend 方法中移除了 if self.use_state_pool: raise NotImplementedError 的守卫。新增分支:使用 cache_indices.clamp(min=0) 处理负数填充索引,预分配 bf16 连续 output_state 供内核直接写入(避免 fp32 中间状态和类型转换),并调用 self._prefill_fn 执行 FlashInfer chunked prefill。SM90 路径基本保持不变。

  3. 增加 CUDA 版本校验:在 server_args.py_handle_linear_attn_backend 中,新增了对 --linear-attn-prefill-backend flashinfer 且 SM100+ 时的 CUDA 版本要求:若 CUDA 主版本 < 13 则抛出 ValueError,因为 CuTe DSL kernel 需要 CUDA 13+。

  4. 新增 CI 测试:添加 test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py,注册到 base-c stage 的 4-gpu-b200 运行器,使用 GSM8K 数据集 (200 条) 验证 FlashInfer 后端准确率不低于 0.95。同时从 test/manual/4-gpu-models/test_qwen35_fp4_triton.py 中删除了之前被注释掉的 FlashInfer 变体(因为功能已完备)。

文件 模块 状态 重要度
python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py 内核层 modified 7.74
python/sglang/srt/server_args.py 配置层 modified 6.52
test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py 测试 added 7.14
test/manual/4-gpu-models/test_qwen35_fp4_triton.py 测试 modified 3.36

关键符号

FlashInferGDNKernel.__init__ FlashInferGDNKernel.extend FlashInferGDNKernel.decode _handle_linear_attn_backend

关键源码片段

python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py core-logic

核心实现文件,实现了 SM100+ 上 FlashInfer GDN prefill 路径,移除 NotImplementedError 并新增 state pool 分支。

def extend(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    ssm_states: torch.Tensor,
    cache_indices: torch.Tensor,
    query_start_loc: torch.Tensor,
    **kwargs,
) -> tuple:
    # ... 预处理 l2norm ...
    if self.use_state_pool:
        # SM100+ 路径:使用 state pool API
        # clamp 负数 padding 索引 ( 如 -1) 到 0 ( 预留给 dummy 序列的 slot)
        ssm_cache_indices = cache_indices.clamp(min=0).to(torch.int64)
        initial_state_fi = ssm_states[ssm_cache_indices].contiguous()
        # 预分配 bf16 output_state,避免内核输出 fp32 后再转换
        output_state_fi = torch.empty_like(initial_state_fi)
        output_fi, output_state_fi = self._prefill_fn(
            q=q_fi,
            k=k_fi,
            v=v_fi,
            g=alpha_fi,
            beta=beta_fi,
            scale=None,
            initial_state=initial_state_fi,
            output_final_state=True,
            cu_seqlens=cu_seqlens_fi,
            use_qk_l2norm_in_kernel=False,
        )
        # 将更新后的状态写回 state pool
        ssm_states[ssm_cache_indices] = output_state_fi
    else:
        # SM90 路径:原有 gather/scatter 逻辑(使用 fp32 状态)
        # ... 保持不变 ...
    return output_fi, output_state_fi
python/sglang/srt/server_args.py configuration

添加了 CUDA 版本校验,确保在 SM100+ 上使用 FlashInfer prefill 时 CUDA >= 13。

# 在 _handle_linear_attn_backend 中,已有 decoder 校验之后添加:
# SM100+ FlashInfer GDN prefill 需要 CUDA 13+(CuTe DSL kernel 要求)
prefill = self.linear_attn_prefill_backend or self.linear_attn_backend
cuda_version = torch.version.cuda
cuda_major = int(cuda_version.split(".")[0]) if cuda_version is not None else 0
if (
    prefill == "flashinfer"
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 10
    and cuda_major < 13
):
    raise ValueError(
        "--linear-attn-prefill-backend flashinfer on SM100+ requires CUDA 13+, "
        f"got CUDA {cuda_version or 'unknown'}"
    )

评论区精华

padding 索引 -1 的来源 question

hlu1 询问代码中 `cache_indices` 的 -1 值从何而来。

结论:kaixih 在注释中说明 -1 是未分配序列的填充标记,使用 clamp(min=0) 将其映射到 0 号 dummy slot 以防止越界。 · 已解决

内核融合建议 性能

hlu1 提议使用 torch.compile 或 triton 融合 l2norm 等小核以减少启动开销。kaixih 以数据流图说明输入来自不同源且 l2norm 是自定义 Triton 核,不适合融合。后续 hlu1 认为 decode 路径已用并行 stream,prefill 因 token 数大暂不适用。

结论:当前 prefill 路径暂不进行融合,但 decode 路径已有类似优化。保持了现有设计。 · 已解决

l2norm strided 输入支持 性能

hlu1 建议修改 l2norm_fwd 内核以支持 strided 输入,从而避免 contiguous 调用。

结论:认可该建议,但未在本 PR 中实现,留待后续优化。 · acknowledged

bf16 state dtype 校验 正确性

yuan-luo 建议为 SM100+ FlashInfer prefill 增加 bf16 state dtype 验证(与 decode 路径一致)。kaixih 解释 prefill 内核支持 fp32 和 bf16 两种状态,而 decode 仅支持 bf16,因此无需额外限制。

结论:维持现状,未添加校验。 · 已解决

CI 测试注册 测试

初始 CI 注册使用了无效的 CUDA suite 名称,导致失败。kaixih 更新为 `stage='base-c', runner_config='4-gpu-b200'` 并验证通过。

结论:测试注册修正后 CI 通过。 · 已解决

风险与影响

  1. CUDA 版本兼容性:SM100+ 上使用 FlashInfer prefill 必须要求 CUDA 13+,对于尚未升级 CUDA 的用户,错误提示清晰,但可能会造成困惑(之前 Triton 后端无需此要求)。已在 server_args 中提供显式校验。
  2. 状态预分配内存开销:为了消除 fp32 中间状态,prefill 路径预分配了 bf16 output_state(torch.empty_like(initial_state_fi)),可能略微增加显存占用,但消除了后续的类型转换和 scatter 开销。
  3. 负数索引 clamp 安全性:将所有负数 padding 索引 clamp 到 0,确保内核不会访问越界状态。但 0 号 dummy slot 必须保证不会被真实序列使用,目前约定如此,若未来索引分配有变则可能造成静默错误。
  4. 性能波动:benchmark 仅针对特定模型和配置(Qwen3.5-397B、TP=8、chunked-prefill-size=163840),实际场景中加速比可能随 batch size 和序列长度变化(如 hlu1 指出的低并行度时增益较小)。

用户影响:使用 NVIDIA Blackwell (SM100+) 并搭配 CUDA 13+ 的用户,在 GDN 线性注意力模型上可以选择 --linear-attn-prefill-backend flashinfer 以获得 ~5% 端到端吞吐提升和 ~6% TTFT 改善。未升级 CUDA 或使用其他硬件的用户无影响。

系统影响:server_args.py 新增了 CUDA 版本校验,会拒绝在黑威上使用 FlashInfer prefill 但 CUDA <13 的配置。

团队影响:新增一个注册的 CI 测试(base-c, 4-gpu-b200),运行 GSM8K 准确率门控,增加了 CI 时长约 720 秒。新增的 prefill 路径需要维护与 FlashInfer 上游的兼容性。

要求 CUDA 13+ 负数索引 clamp 安全性 状态预分配内存开销

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论