Prhub

#43534 [CPU][Perf] Enable fused kernels for GDN's gated delta rules

原始 PR 作者 fadara01 合并时间 2026-06-02 16:00 文件变更 11 提交数 3 评论 23 代码增减 +812 / -585

执行摘要

CPU GDN 融合内核性能提升 50%

PR 描述指出:'makes the gated delta rule impls from sglang-kernels ISA agnostic',并统一 AMX 与其他 CPU ISA 的路径。对于缺少快速 brgemm 的 ISA(如非 x86),使用 OpenBLAS 或 PyTorch BLAS 回退。同时修复了 fused_sigmoid_gating_delta_rule_update 内核中 sigmoid 计算的 bug(beta 指针使用了错误索引)。此 PR 是在 #41025 中承诺的后续优化。

值得精读,特别关注其 ISA 无关的 BLAS 回退架构和编译时分支策略,该模式可推广至其他需要跨平台性能优化的算子。建议阅读文件:csrc/cpu/sgl-kernels/blas_gemm.hgemm.h 中的 brgemm_supported 设计。

讨论亮点
  • BLAS fallback 类型不匹配争议:gemini-code-assist 指出在 blas_gemm fallback 路径中,输出缓冲为 float* 而 PyTorch stub 期望 scalar_t* 可能导致内存损坏。作者回应已使用 gemm_no_downcast_stub(而非 gemm_stub),保证了输出始终为 fp32,不存在类型不匹配。

  • 测试数值稳定性讨论:bigPYJ1151 询问 GDN 测试是否数值稳定,因其经验中随机输入容易失败。作者解释稳定性得益于修复了 sigmoid 计算中的 beta 索引 bug,且测试序列长度较短,多次试验均通过。

  • 构建兼容性要求:bigPYJ1151 建议将新增的 BLAS 函数移至单独头文件,方便未来与 sglang 源码同步;作者采纳并创建 blas_gemm.h,同时将 brgemm_supported() 宏保留在 gemm.h 中。

  • OpenBLAS 版本兼容性:aditew01 提醒 OpenBLAS <0.3.30 可能不兼容,作者表示只依赖 PyTorch libtorch 自带的 OpenBLAS 版本,无此顾虑。

实现拆解

  1. 移除旧 Python 实现并添加 C++ 融合内核:删除 recurrent_gated_delta_rule.py(223 行),避免 Python 层 GEMM 开销。在 fla.cpp 中新增 fused_sigmoid_gating_delta_rule_update_kernel_impl 和增强 chunk_gated_delta_rule_kernel_impl,将 gating、chunk decay 和 delta rule 计算融合为单个内核。

  2. 实现 ISA 无关的 BLAS 回退层:新建 blas_gemm.h,提供 blas_gemm 重载。当定义了 VLLM_HAS_OPENBLAS(由 cmake 检测)时,直接调用 OpenBLAS 的 sbgemm_/sgemm_ 进行 bf16/float GEMM;否则回退到 PyTorch 的 gemm_no_downcast_stub(输出 fp32)。在 gemm.h 中添加编译时常量 brgemm_supported(),基于 AVX512BF16/AMX 指令集判断是否可用,并修改 can_use_brgemm 模板函数与之关联。

  3. 改造 Python 调用入口:修改 gdn_attention.py,移除对旧 Python 函数 recurrent_gated_delta_rulegdn_gating 的导入,改为直接调用 ops 中注册的 C++ 函数。同时删除了两层分支(AMX 专用路径与非 AMX 路径),统一为条件编译的内核。

  4. 添加完整测试:新建 test_cpu_gdn_ops.py,包含三个参数化测试用例,覆盖 fused gdn gating、fused sigmoid gating delta rule update 和 chunk gated delta rule,与 Python 参考实现对比验证。测试配置了多种 batch 大小、序列长度和 head 维度,确保边界正确。

  5. 更新构建与 CI:修改 cpu_extension.cmake 添加 OpenBLAS 检测和 VLLM_HAS_OPENBLAS 宏;在 CI 配置文件 cpu.yamlrun-cpu-test-arm.sh 中注册新测试路径,确保 ARM 和 x86 节点都能运行。

文件 模块 状态 重要度
csrc/cpu/sgl-kernels/fla.cpp CPU 内核 modified 7.76
csrc/cpu/sgl-kernels/blas_gemm.h CPU 内核 added 7.6
csrc/cpu/sgl-kernels/gemm.h CPU 内核 modified 6.5
vllm/model_executor/layers/mamba/ops/cpu/gdn_attention.py 模型执行器 modified 6.69
tests/kernels/mamba/cpu/test_cpu_gdn_ops.py 测试 added 7.48
vllm/model_executor/layers/mamba/ops/cpu/recurrent_gated_delta_rule.py 模型执行器 removed 6.5

关键符号

chunk_gated_delta_rule_kernel_impl fused_sigmoid_gating_delta_rule_update_kernel_impl blas_gemm brgemm_supported cpu_gdn_attention_core ref_gated_delta_rule test_fused_gdn_gating_cpu test_fused_sigmoid_gating_delta_rule_update_cpu test_chunk_gated_delta_rule_cpu

关键源码片段

csrc/cpu/sgl-kernels/fla.cpp core-logic

核心 C++ 实现文件,添加了融合内核(fused_sigmoid_gating_delta_rule_update_kernel_impl 和增强 chunk_gated_delta_rule_kernel_impl),并使用编译时分支动态选择 brgemm 或 blas_gemm,是性能提升的关键。

// csrc/cpu/sgl-kernels/fla.cpp
// 展示 chunk_gated_delta_rule_kernel_impl 中的关键分支:
// 通过编译时 brgemm_supported() 决定使用 AMX brgemm 还是通用 BLAS
if constexpr (brgemm_supported()) {
    // AMX 路径:pack + brgemm
    pack_vnni<scalar_t>(k_transpose, curr_k_pad, chunk_size, qk_head_size, qk_head_size, chunk_size);
    at::native::cpublas::brgemm(
        chunk_size, chunk_size, qk_head_size, qk_head_size, chunk_size, chunk_size,
        false, curr_k_beta, k_transpose, curr_attn);
} else {
    // 非 x86 回退路径:调用 blas_gemm (BF16 GEMM 输出 FP32)
    blas_gemm(
        at::native::TransposeType::Transpose, // A = k_pad^T
        at::native::TransposeType::NoTranspose, // B = k_beta
        chunk_size, chunk_size, qk_head_size,
        1.0f,
        curr_k_pad, qk_head_size,
        curr_k_beta, qk_head_size,
        0.0f,
        curr_attn, chunk_size
    );
}
// 后续 attn = attn * decay_mask 逻辑保持不变
csrc/cpu/sgl-kernels/blas_gemm.h dependency-wiring

新增的 BLAS 回退头文件,封装了 OpenBLAS 和 PyTorch 两种调用方式,确保在无 AMX 的 CPU 上仍能高效执行 GEMM。

// csrc/cpu/sgl-kernels/blas_gemm.h
// 提供 ISA 无关的 BLAS gemm 封装
#include <ATen/native/CPUBlas.h>#if defined(VLLM_HAS_OPENBLAS)
// 直接链接 OpenBLAS 符号
inline void blas_gemm(... BFloat16 ...) {
    char transa_ = ..., transb_ = ...;
    int m_ = m, n_ = n, k_ = k;
    extern "C" void sbgemm_(...);
    sbgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha, a, &lda_, b, &ldb_, &beta, c, &ldc_);
}
inline void blas_gemm(... float ...) {
    // 类似,调用 sgemm_
}
inline void blas_gemm(... Half ...) {
    TORCH_CHECK(false, "CPU OpenBLAS hgemm is not available.");
}
#else
// 使用 PyTorch 的 cpublas stub,输出不降精度 (fp32)
template <typename scalar_t>
inline void blas_gemm(...) {
    auto gemm = at::native::cpublas::gemm_no_downcast_stub.DEFAULT;
    gemm(..., a, lda, b, ldb, beta, c, ldc);
}
#endif
csrc/cpu/sgl-kernels/gemm.h dependency-wiring

添加了 brgemm_supported() 编译时检测和 CPU_CAPABILITY_AVX512 宏,控制 can_use_brgemm 返回值,是 ISA 无关性的关键。

// csrc/cpu/sgl-kernels/gemm.h
// 添加编译时 AMX 能力检测
#include "blas_gemm.h"  // 替换原来的 <ATen/native/CPUBlas.h>#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
#define CPU_CAPABILITY_AVX512
#endifconstexpr bool brgemm_supported() {
#if defined(CPU_CAPABILITY_AVX512)
    return true;
#else
    return false;
#endif
}// 所有 can_use_brgemm 特化都加上 brgemm_supported() 条件

评论区精华

BLAS fallback 类型不匹配 正确性

gemini-code-assist 指出 blas_gemm fallback 中 float* 输出与 PyTorch stub 期望标量类型不一致,可能内存损坏。作者澄清使用了 gemm_no_downcast_stub,输出始终 fp32,不会 mismatch。

结论:不构成问题,已关闭。 · 已解决

测试数值稳定性质疑 测试

bigPYJ1151 指出 GDN 测试通常对随机输入敏感,容易失败。作者回应稳定性得益于本 PR 修复的 sigmoid beta bug,且序列长度较短,多次种子测试均通过。

结论:稳定,无需额外处理。 · 已解决

BLAS 函数迁移至单独头文件 设计

bigPYJ1151 建议将新增的 BLAS 函数从 gemm.h 移至单独头文件,以方便未来与 sglang 源码同步。作者采纳并创建 blas_gemm.h。

结论:已实现。 · 已解决

OpenBLAS 版本兼容性 question

aditew01 提醒 OpenBLAS <0.3.30 可能不兼容。作者说明只依赖 PyTorch libtorch 内部的 OpenBLAS,不直接链接,无需额外检查。

结论:无需处理。 · 已解决

风险与影响

  • 非 x86 回退路径验证有限blas_gemm.h 中的回退路径仅在 CI 的 ARM 环境下运行,其他非 x86 平台(如 RISC-V)未测试。
  • 移除旧实现后无 Python 回退:一旦 C++ 融合内核或 BLAS 回退崩溃,将无法恢复为纯 Python 版本,但测试覆盖率降低了风险。
  • OpenBLAS 绑定风险blas_gemm.h 直接调用 sbgemm_ 等 Fortran 命名约定符号,若链接的 OpenBLAS 版本不匹配 ABI,可能导致符号未定义或崩溃。不过由于零零通过 PyTorch 的 libtorch 动态链接,风险较低。
  • FP16 路径未实现blas_gemmat::Half 的重载直接 TORCH_CHECK(false),但 GDN 当前仅使用 bf16,因此无实际影响,但未来扩展需注意。
  • 用户影响:在支持 AMX 的 x86 CPU 和 Neoverse V2 等 ARM CPU 上,Mamba 模型的推理吞吐量显著提升;无功能变更,输出保持兼容。
  • 系统影响:减少了 Python 和 C++ 之间的上下文切换,降低了主存占用,但增加了约 200KB 的 C++ 二进制代码。
  • 团队维护:新引入的 blas_gemm.h 增加了对平台 BLAS 库的依赖,但封装在单头文件中维护成本可控。未来可从 sglang 上游同步内核更新。
非 x86 回退覆盖待增强 移除 Python 回退无容错 FP16 路径未实现

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论