Prhub

#38112 [CPU] Added faster exp routine for lower precision data types.

原始 PR 作者 almayne 合并时间 2026-04-23 21:14 文件变更 2 提交数 6 评论 19 代码增减 +76 / -6

执行摘要

为 ARM CPU BF16/FP16 注意力添加快速 exp

降低低精度注意力计算中exp运算的开销,提升ARM CPU推理性能。PR body观察到对neoverse V1有3-4%加速。

值得精读,特别是其平台特定优化与通用代码的分离策略,以及编译期分派的实践。设计讨论展现了在ISA无关代码中集成特殊优化的权衡,对后续类似改动有参考价值。

讨论亮点
  • fadara01要求将最初的运行时分支is_neon_f16改为编译期分派,避免性能开销。最终采用if constexpr + #ifdef __aarch64__方案。
  • bigPYJ1151提议在attention_impl_t中添加标志或将分派逻辑移到DEFINE_FAST_EXP宏内,但almayne指出fast_exp也被MoE使用,宏作用域内无query_t类型。最终决定在cpu_attn_impl.h中添加平台宏分支,虽然稍显杂乱但更清晰。
  • fadara01bigPYJ1151对x86平台是否需要假实现展开讨论,最终选择在cpu_attn_impl.h中加入#ifdef __aarch64__条件编译,避免在x86上引入未使用的代码。

实现拆解

  1. 定义快速exp函数(csrc/cpu/cpu_arch_macros.h:在DEFINE_FAST_EXP宏中新增fast_exp_f16,它基于Arm优化库中的expf算法,使用三阶多项式近似 exp(r) ≈ 1 + r + r^2*(c3 + c2*r),精度在FP16/BF16下为1ULP,输入范围限制在[-87.683, 88.376]之外时饱和到0或inf。
  2. 集成到softmax(csrc/cpu/cpu_attn_impl.hpp:在apply_softmaxapply_softcap函数中,通过编译期常量IsReducedPrecision判断query类型是否为BF16/Half,再结合#ifdef __aarch64__宏,在ARM平台上为低精度类型调用fast_exp_f16,否则回退到原生fast_expstd::exp
  3. 避免运行时分支:所有分派都在编译期完成,通过if constexpr和预处理宏确保非ARM平台不受影响,ARM上低精度路径使用更快的近似。
  4. 配套注释更新:在宏定义中添加详细说明,解释与标准exp_u20的差异及精度保证。
  5. 性能验证:使用benchmark_cpu_attn.py在neoverse V1上测试,batch_size=64, seq_len=512, bfloat16,观察到3-4%的注意力计算加速。
文件 模块 状态 重要度
csrc/cpu/cpu_arch_macros.h CPU 内核 modified 6.88
csrc/cpu/cpu_attn_impl.hpp CPU 内核 modified 6.87

关键符号

fast_exp_f16 apply_softmax apply_softcap

关键源码片段

csrc/cpu/cpu_arch_macros.h core-logic

定义了 fast_exp_f16 函数,是性能优化的核心实现。

// 在 DEFINE_FAST_EXP 宏中新增 fast_exp_f16
// 基于 Arm 优化库 expf AdvSIMD,但使用更低阶多项式
auto neon_expf_f16 = [&](float32x4_t values) __attribute__((always_inline)) {
    // 输入范围限制:[-87.683, 88.376] 外饱和
    const uint32x4_t lt_lower = vcltq_f32(values, lower_bound);
    const uint32x4_t gt_upper = vcgtq_f32(values, upper_bound);
    float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
    float32x4_t r = vfmsq_n_f32(values, n, ln2);
    uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
    float32x4_t r2 = vmulq_f32(r, r);
    // exp(r) ≈ 1 + r + r^2*(c3 + c2*r),三阶多项式
    float32x4_t q = vfmaq_n_f32(f_c3, r, f_c2);
    float32x4_t s = vaddq_f32(vdupq_n_f32(1.0f), r);
    float32x4_t p = vfmaq_f32(s, q, r2);
    float32x4_t y = vreinterpretq_f32_u32(vaddq_u32(vreinterpretq_u32_f32(p), e));
    y = vbslq_f32(lt_lower, vdupq_n_f32(0.0f), y);
    y = vbslq_f32(gt_upper, vdupq_n_f32(INFINITY), y);
    return y;
};
// 包装为 FP32Vec16 接口,处理四个 128 位向量
auto fast_exp_f16 = [&](const vec_op::FP32Vec16& vec) __attribute__((always_inline)) {
    float32x4x4_t result;
    result.val[0] = neon_expf_f16(vec.reg.val[0]);
    result.val[1] = neon_expf_f16(vec.reg.val[1]);
    result.val[2] = neon_expf_f16(vec.reg.val[2]);
    result.val[3] = neon_expf_f16(vec.reg.val[3]);
    return vec_op::FP32Vec16(result);
};
csrc/cpu/cpu_attn_impl.hpp core-logic

集成了 fast_exp_f16,在 softmax/softcap 中根据类型和平台选择 exp 实现。

// 在 apply_softmax 函数中,exp 计算部分
#if defined(DEFINE_FAST_EXP)
    // 编译期判断是否低精度类型(BF16/Half)
    bool constexpr IsReducedPrecision =
        std::is_same_v<query_t, c10::BFloat16> ||
        std::is_same_v<query_t, c10::Half>;
    // 仅 ARM 平台使用 fast_exp_f16,否则用通用 fast_exp
    #ifdef __aarch64__
    if constexpr (IsReducedPrecision) {
        vec = fast_exp_f16(vec);
    } else
    #endif
    {
        vec = fast_exp(vec);
    }
    // 保存为 prob_buffer_t
    prob_buffer_vec_t output_vec(vec);
    output_vec.save(curr_prob_buffer_iter);
#else
    // 不使用快速 exp 时,调用 std::exp
    vec.save(curr_logits_buffer_iter);
    for (int32_t k = 0; k < 16; ++k) {
        curr_logits_buffer_iter[k] = std::exp(curr_logits_buffer_iter[k]);
    }
    vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif

评论区精华

避免运行时分支,改为编译期分派 性能

fadara01 建议将最初的运行时分支 `is_neon_f16` 改为模板参数或编译期判断,以避免运行时开销。almayne 改为在 cpu_attn_impl.h 中使用 `if constexpr` 和 `#ifdef __aarch64__`。

结论:采用 `#ifdef __aarch64__` + `if constexpr (IsReducedPrecision)` 方案,在 ARM 上为低精度类型启用 fast_exp_f16。 · 已解决

fast_exp_f16 的集成位置 设计

bigPYJ1151 提议将分派逻辑放在 DEFINE_FAST_EXP 宏内,但 almayne 指出该宏也被 MoE 使用,且宏作用域内没有 query_t 类型。最终决定在 cpu_attn_impl.h 中加平台检查。

结论:在 cpu_attn_impl.h 中使用 `#ifdef __aarch64__` 条件编译,明确平台相关优化。 · 已解决

x86 平台是否需要假实现 设计

bigPYJ1151 最初建议在 x86 上也定义 fast_exp_f16 为 fast_exp 的别名,fadara01 认为这会在 ISA 无关代码中引入混乱。最终选择在 cpu_attn_impl.h 中通过 `#ifdef __aarch64__` 避免在 x86 上使用 fast_exp_f16。

结论:不添加假实现,仅在 ARM 平台编译并调用 fast_exp_f16。 · 已解决

风险与影响

精度风险:fast_exp_f16仅针对BF16/FP16保证1ULP精度,若误用于FP32注意力可能产生较大误差(但当前仅在低精度类型下启用)。平台风险:仅ARM NEON生效,x86使用原有fast_exp,无影响。性能风险:已通过基准测试验证提升,无回归。兼容性:不涉及接口变化,仅内部实现优化。

对ARM CPU用户:BF16/FP16注意力计算获得3-4%性能提升,推理延迟降低。对团队:新增ARM平台特定代码,需维护注释和精度约束,但改动集中在两个文件,影响可控。其他后端无任何变化。

精度约束限于低精度类型 仅 ARM 平台生效 未包含新测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论