执行摘要
- 一句话:FP8 ASM 预填充加速 ROCm gfx950 MLA 预填充
- 推荐动作:值得精读,特别是如何设计自动检测与优雅回退、以及在元数据构建阶段预计算以避免 forward 中同步的技巧,对编写高性能 attention 后端有参考价值。
功能与动机
AITER 密集 MLA 后端默认使用 flash_attn_varlen_func 进行预填充,gfx950 上 AITER 提供了更高效的 FP8 ASM 预填充内核(mla_prefill_ps_asm_fwd + mla_reduce_v1),可显著降低 TTFT 并提高吞吐量。本 PR 自动启用该优化,无需用户调整配置。
实现拆解
- 自动检测:新增
_fp8_mla_prefill_supported() 函数,使用 on_gfx950() 和导入 mla_prefill_ps_asm_fwd、mla_reduce_v1 来判断是否支持 FP8 ASM 预填充,结果被 LRU 缓存。
- 元数据扩展:在
AiterMLAMetadata 中添加 10 个 fp8_prefill_* 可选字段,用于存储持久调度(PS)元数据,如 Q/KV 索引、work 信息、reduce 映射等。
- 缓冲区预分配:
AiterMLAMetadataBuilder.__init__ 中调用 _init_fp8_prefill_ps_buffers(),根据 max_num_reqs 和 max_prefill_qlen 通过 get_ps_metadata_info_v1 计算最大规模并预分配设备张量。
- 元数据构建:在
build() 中,当有预填充批次且 FP8 预填充启用时,调用 _build_fp8_prefill_ps_metadata(),从 common_attn_metadata.query_start_loc_cpu 切片并填充 PS 元数据;同时预计算 num_partial_tiles 并存入 fp8_prefill_num_partial_tiles 以避免 forward 中的同步。
- 前向分发:
AiterMLAImpl.forward_mha 中,当预填充存在且 FP8 启用时,调用 _mla_fp8_prefill_attn(q, k, v, attn_metadata, output),该函数使用 workspace manager 获取临时缓冲区,执行 mla_prefill_ps_asm_fwd 和 mla_reduce_v1 直接写入 output,否则回退到 flash_attn_varlen_func。
- 性能优化:审查后消除了
.to("cpu") 和 .item() 同步、不必要的张量分配和复制,并使用 workspace manager 管理中间张量。
关键文件:
vllm/v1/attention/backends/mla/rocm_aiter_mla.py(模块 注意力后端;类别 source;类型 core-logic;符号 _fp8_mla_prefill_supported, _init_fp8_prefill_ps_buffers, _build_fp8_prefill_ps_metadata, _mla_fp8_prefill_attn): 唯一的变更文件,实现了 FP8 ASM 预填充的全部逻辑,包括自动检测、元数据缓冲区预分配、元数据构建和前向分发。
关键符号:_fp8_mla_prefill_supported, _init_fp8_prefill_ps_buffers, _build_fp8_prefill_ps_metadata, _mla_fp8_prefill_attn, forward_mha
关键源码片段
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
唯一的变更文件,实现了 FP8 ASM 预填充的全部逻辑,包括自动检测、元数据缓冲区预分配、元数据构建和前向分发。
# Auto-detect FP8 ASM prefill support
@functools.lru_cache(maxsize=1)
def _fp8_mla_prefill_supported() -> bool:
"""Check if platform is gfx950 and AITER provides the required kernels."""
try:
from vllm.platforms.rocm import on_gfx950
except Exception:
return False
if not on_gfx950():
return False
try:
from aiter import mla_prefill_ps_asm_fwd, mla_reduce_v1
except Exception:
return False
return True
# FP8 prefill PS metadata fields in AiterMLAMetadata
@dataclass
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
# ... original decode fields ...
fp8_prefill_qo_indptr: torch.Tensor | None = None
fp8_prefill_kv_indptr: torch.Tensor | None = None
fp8_prefill_kv_indices: torch.Tensor | None = None
fp8_prefill_work_indptr: torch.Tensor | None = None
fp8_prefill_work_info_set: torch.Tensor | None = None
fp8_prefill_reduce_indptr: torch.Tensor | None = None
fp8_prefill_reduce_final_map: torch.Tensor | None = None
fp8_prefill_reduce_partial_map: torch.Tensor | None = None
fp8_prefill_max_q_len: int | None = None
fp8_prefill_num_partial_tiles: int | None = None
# Buffer pre-allocation in builder __init__
class AiterMLAMetadataBuilder(...):
def __init__(self, ...):
super().__init__(...)
self._fp8_prefill_enabled = _fp8_mla_prefill_supported()
if self._fp8_prefill_enabled:
max_prefill_qlen = min(
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.max_num_batched_tokens,
)
self._init_fp8_prefill_ps_buffers(
max_num_reqs, max_prefill_qlen, device
)
评论区精华
风险与影响
- 风险:技术风险较低,因为新增路径仅在 gfx950 + AITER 提供对应内核时启用,否则静默回退到原有 flash_attn_varlen_func。主要风险包括:依赖外部 aiter 库版本;预分配缓冲区增加少量显存占用;性能提升有限(TTFT -14.8%,输出吞吐 +2.3%),且 TPOT 略有噪声;测试未覆盖回退路径及多 segment 组合。
- 影响:影响范围限于 gfx950 用户(MI355X),他们将自动获得预填充性能提升,无需配置变更。其他硬件或配置不受影响。系统层面增加少量显存占用用于 PS 元数据缓冲区。团队需要维护新增的自动检测和元数据逻辑,但代码集中在单一文件且带有清晰注释。
- 风险标记:核心路径变更, 依赖外部库, 仅限 gfx950, 缺少测试覆盖
关联脉络
- PR #42604 DeepSeekV4-Pro enable cuda graph full and piecewise mode: 同属 ROCm MLA 注意力后端优化系列,涉及同一硬件平台和模型族。
参与讨论