执行摘要
- 一句话:避免 ROCm 编译时重复 bf16 转换
- 推荐动作:该 PR 改动小巧、聚焦,验证充分(性能、精度、FX dump),建议合并。值得注意的设计决策:通过额外 buffer 而非修改全局 dtype 来避免精度影响,以及将条件守卫精确限定在编译时快路径。
功能与动机
在 GPT-OSS 的 decode 编译图中,逐层调用 cos_sin_cache.to(query.device, dtype=query.dtype) 产生重复的 bf16 转换节点。预计算并复用 bf16 缓存可删除这些冗余节点,同时保持现有运行语义。
实现拆解
- 注册预计算 bf16 缓存:在
RotaryEmbeddingBase.__init__ 中,当启用 AITER 且主缓存 dtype 非 bf16 时,额外注册一个 cos_sin_cache_bf16 buffer,值为 cache.to(torch.bfloat16),仅保留在 GPU 上。
- 添加编译时快路径:在
_match_cos_sin_cache_dtype 中,当满足条件:use_aiter、torch.compiler.is_compiling()、且 query dtype 为 bf16 时,直接返回预计算的 cos_sin_cache_bf16(需设备匹配),避免执行 fallback 的 .to() 调用。
- 保持回退路径:不满足 AITER 编译快路径的条件时,逻辑不变。
关键文件:
vllm/model_executor/layers/rotary_embedding/base.py(模块 模型执行层;类别 source;类型 data-contract;符号 RotaryEmbeddingBase.init, RotaryEmbeddingBase._match_cos_sin_cache_dtype): 唯一修改的文件,在 RotaryEmbeddingBase 的 init 和 _match_cos_sin_cache_dtype 中添加了预计算 bf16 缓存和编译时快路径。
关键符号:RotaryEmbeddingBase.init, RotaryEmbeddingBase._match_cos_sin_cache_dtype
关键源码片段
vllm/model_executor/layers/rotary_embedding/base.py
唯一修改的文件,在 RotaryEmbeddingBase 的 init 和 _match_cos_sin_cache_dtype 中添加了预计算 bf16 缓存和编译时快路径。
# 在 __init__ 中,主缓存初始化后添加预计算 bf16 缓存
if init_cache:
cache = self._compute_cos_sin_cache()
if not self.use_flashinfer:
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
# 为 AITER 编译路径预计算 bf16 缓存,避免逐层重复 cast 节点
if self.use_aiter and cache.dtype != torch.bfloat16:
self.register_buffer(
"cos_sin_cache_bf16",
cache.to(torch.bfloat16),
persistent=False,
)
else:
# 明确置为 None,确保 _match_cos_sin_cache_dtype 中 getattr 能正确判断
self.cos_sin_cache_bf16 = None
# 在 _match_cos_sin_cache_dtype 中添加快路径
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> torch.Tensor:
cos_sin_cache = self.cos_sin_cache
# 如果设备且 dtype 已匹配,直接返回(原逻辑)
if (
cos_sin_cache.device == query.device
and self.cos_sin_cache.dtype == query.dtype
):
return cos_sin_cache
# AITER 编译快路径:查询为 bf16 且预计算缓冲可用时直接复用
if (
self.use_aiter
and torch.compiler.is_compiling()
and query.dtype == torch.bfloat16
):
cache_bf16 = getattr(self, "cos_sin_cache_bf16", None)
if cache_bf16 is not None and cache_bf16.device == query.device:
return cache_bf16
# 回退路径:执行 device/dtype 转换
cos_sin_cache = cos_sin_cache.to(query.device, dtype=query.dtype)
# 编译时避免修改 buffer
if torch.compiler.is_compiling():
return cos_sin_cache
self.cos_sin_cache = cos_sin_cache
return cos_sin_cache
评论区精华
reviewer tjtanaa 询问为何不直接修改主缓存 dtype 以节省内存。作者 akii96 回应:GPT-OSS 硬编码了 dtype=torch.float32,直接改 dtype 会改变模型预期精度;而编译时 buffer mutation 被 cudagraph 阻断导致逐层产生冗余 cast。额外 buffer 仅 4MB,且仅限于 AITER 编译路径,不影响其他行为。
- 额外 buffer 内存开销 (design): akii96 解释:GPT-OSS 硬编码 dtype=float32,直接改 dtype 会改变模型精度;编译时 buffer mutation 被 cudagraph 阻止导致逐层冗余 cast;额外 buffer 仅约 4MB,且仅限于 AITER 编译路径,不影响其他行为。
风险与影响
- 风险:风险较低:仅影响 ROCm AITER 编译路径,且通过条件守卫(use_aiter、is_compiling、query.dtype == bf16)隔离。内存增加约 4MB(取决于 max_position_embeddings 和 rotary_dim),在可接受范围内。非 ROCm 或非编译路径行为完全不变。
- 影响:影响范围:仅在使用 ROCm AITER、torch.compile 且 query 为 bf16 的 GPT-OSS 模型场景下生效。吞吐提升 1.7%-1.9%,TPOT 降低 1.6%-3.1%。LM-eval GSM8K 精度指标保持不变。
- 风险标记:内存增加约4MB
关联脉络
- PR #39177 [ROCm][Perf] Expose AITER MoE sorting dispatch policy via env var: 均为 ROCm AITER 相关性能优化,共享 use_aiter 基础设施。
参与讨论