Prhub

#39445 [Feat] CPU fp8 attn for AMX/AVX-512

原始 PR 作者 tianmu-li 合并时间 2026-04-29 20:43 文件变更 19 提交数 44 评论 79 代码增减 +1068 / -284

执行摘要

CPU 后端新增 FP8 KV 缓存量化支持

之前 CPU 平台完全禁止 FP8 KV 缓存:抛出硬错误或静默降级为 BF16,导致用户无法利用 FP8 降低 KV 缓存内存。PR 移除了这些限制,并为 AMX/AVX-512 平台实现了高效的 FP8 内核。

建议精读该 PR,尤其是 generate_cpu_attn_dispatch.py 的调度设计、TileGemm 模板的扩展方式以及 FP8 去量化与 GEMM 的融合技巧。对关注 CPU 推理性能优化的读者有较高参考价值。

讨论亮点

主要 review 讨论由 bigPYJ1151 主导,核心包括:建议统一 FP8 和普通 C++ 接口,避免独立函数;使用 c10::Float8_e4m3fn/e5m2 官方类型替代 uint8;通过 if constexpr 合并 TileGemm 特化以减少代码重复;将 scratch buffer 移出循环以提升性能;重命名 kv_cache_t 为 kv_cache_scalar_t 以区分内部类型;调整 dispatch 宏以包含 FP8 路径,并为 AVX2 和 AVX512 分别生成不同 case。作者逐一采纳并解决,最终获得 reviewer 的 'LGTM' 批准。

实现拆解

  1. 新增 FP8 内核和工具函数:创建 csrc/cpu/cpu_attn_fp8.hpp,实现标量量化和反量化函数、AMX 友好的 reshape 内核(半字打包 K 和子组打包 V),以及 VEC 路径的加载辅助。
  2. 扩展 GEMM 模板以支持 FP8 缓存类型:修改 TileGemm224/TileGemm122(AMX)和 TileGemm82(VEC),通过添加 q_buffer_t 和 kv_cache_t 模板参数,并在 AMX 路径中插入 deq_tile_amx / prepare_b_tile 去量化步骤;VEC 路径重写 load_b_pair_vec 以处理 FP8。
  3. 重构调度代码:更新 generate_cpu_attn_dispatch.py,将 kv_cache 类型编码到调度键中,生成额外的 FP8 case;cpu_attn.cpp 中的入口函数统一接收 kv_cache_dtype 和尺度参数并分派到正确的 AttentionImpl 特化。
  4. 修改 Python 后端:在 cpu_attn.py 中传递 k_scale、v_scale 和 kv_cache_dtype 给 C++ 函数;在 cpu.py 中移除旧限制,对非 x86 平台保留 NotImplementedError。
  5. 添加测试覆盖:test_cpu_attn.py 包含 71 个 FP8 相关测试,验证量化和反量化正确性、数值精度和端到端性能。
文件 模块 状态 重要度
csrc/cpu/generate_cpu_attn_dispatch.py 调度生成器 modified 8.84
csrc/cpu/cpu_attn_fp8.hpp FP8 内核 added 7.93
csrc/cpu/cpu_attn_amx.hpp AMX 实现 modified 8.38
csrc/cpu/cpu_attn_vec.hpp VEC 实现 modified 7.46
csrc/cpu/cpu_attn_impl.hpp 核心框架 modified 7.26
csrc/cpu/cpu_attn.cpp 入口函数 modified 6.98
vllm/v1/attention/backends/cpu_attn.py 后端集成 modified 6.56
vllm/platforms/cpu.py 平台层 modified 5.89
tests/kernels/attention/test_cpu_attn.py 注意力测试 modified 6.05

关键符号

encode_params _make_case generate_cases_for_isa_group fp8e4m3_to_float_scalar float_to_fp8e4m3_scalar reshape_and_cache_fp8_amx_impl deq_tile_amx prepare_b_tile load_b_pair_vec parse_fp8_kv_dtype cpu_attn_reshape_and_cache cpu_attention_with_kv_cache AttentionMainLoop::process AttentionImpl::init_from_input AttentionImpl::get_output_v_scale

分析完成后,这里会展示 LLM 生成的相对完整源码片段和详细注释。

评论区精华

统一 FP8 和非 FP8 C++ 接口 设计

bigPYJ1151 建议将 cpu_attention_with_kv_cache_fp8 和 cpu_attention_with_kv_cache 合并到一个函数,通过新增参数区分 FP8 路径。

结论:作者采纳,将两个函数合并,使用 kv_cache_dtype 参数内部调度。 · 已解决

使用 c10::Float8_e4m3fn / e5m2 类型 style

bigPYJ1151 建议使用 PyTorch 官方 FP8 类型替代 uint8。

结论:作者更新代码使用 c10::Float8_e4m3fn 和 c10::Float8_e5m2。 · 已解决

合并 TileGemm 特化减少代码重复 refactor

bigPYJ1151 建议通过 if constexpr 将 TileG224<BFloat16, fp8_t> 和 Tagem224<BFloat16, BFloat16> 统一为 TileGemm224<BFloat16, kv_cache_t>。

结论:作者采纳,用 if constexpr 分支处理 FP8 去量化。 · 已解决

将 scratch buffer 移出循环 性能

bigPYJ1151 指出 scratch 数组在循环内分配,建议移到循环外以提升性能。

结论:作者采纳,将 scratch buffer 声明移到 k_times 循环前。 · 已解决

重命名 kv_cache_t 为 kv_cache_scalar_t style

bigPYJ1151 建议重命名模板参数以区分内部 kv_cache_t 字段。

结论:作者更新所有相关文件。 · 已解决

调整 dispatch 宏生成不同平台的 FP8 case 设计

bigPYJ1151 建议为 AVX2 和 AVX512 生成不同 case,且默认不启用 FP8。

结论:作者添加了条件编译块,FP8 case 仅在 AVX512 和 AMX 平台启用。 · 已解决

风险与影响

主要风险包括:1)AMX/VEC 模板修改可能影响非 FP8 路径的正确性和性能,但已有全量测试覆盖;2)FP8 反量化增加计算开销,但通过尺度折叠和高效向量化实现性能增益;3)非 x86 平台(ARM/s390x)会抛出 NotImplementedError,需确保用户收到清晰错误信息;4)新增 dispatch 维度增加模板实例化数量和编译时间;5)尺度参数传递需保持与 GPU 后端一致,避免语义差异。

影响范围限定于 CPU 后端注意力计算模块。为 CPU 用户提供 FP8 KV 缓存选项,可减少大约 50% 的 KV 缓存内存占用(相对于 BF16),同时通过降低内存带宽需求提升吞吐量。对非 x86 平台无行为影响(保留错误提示)。涉及 C++ 内核、Python 绑定和调度代码,团队需维护新引入的 cpu_attn_fp8.hpp 文件。

核心注意力路径变更 新增 FP8 代码路径 非 x86 平台限制 模板实例化膨胀 尺度参数语义一致性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论