Prhub

#42833 [ROCm][GPT-OSS] Avoid repeated compile-time `cos_sin_cache.to(bf16)` casts in rotary path

原始 PR 作者 akii96 合并时间 2026-05-27 16:22 文件变更 1 提交数 1 评论 3 代码增减 +21 / -0

执行摘要

避免 ROCm 编译时重复 bf16 转换

在 GPT-OSS 的 decode 编译图中,逐层调用 cos_sin_cache.to(query.device, dtype=query.dtype) 产生重复的 bf16 转换节点。预计算并复用 bf16 缓存可删除这些冗余节点,同时保持现有运行语义。

该 PR 改动小巧、聚焦,验证充分(性能、精度、FX dump),建议合并。值得注意的设计决策:通过额外 buffer 而非修改全局 dtype 来避免精度影响,以及将条件守卫精确限定在编译时快路径。

讨论亮点

reviewer tjtanaa 询问为何不直接修改主缓存 dtype 以节省内存。作者 akii96 回应:GPT-OSS 硬编码了 dtype=torch.float32,直接改 dtype 会改变模型预期精度;而编译时 buffer mutation 被 cudagraph 阻断导致逐层产生冗余 cast。额外 buffer 仅 4MB,且仅限于 AITER 编译路径,不影响其他行为。

实现拆解

  1. 注册预计算 bf16 缓存:在 RotaryEmbeddingBase.__init__ 中,当启用 AITER 且主缓存 dtype 非 bf16 时,额外注册一个 cos_sin_cache_bf16 buffer,值为 cache.to(torch.bfloat16),仅保留在 GPU 上。
  2. 添加编译时快路径:在 _match_cos_sin_cache_dtype 中,当满足条件:use_aitertorch.compiler.is_compiling()、且 query dtype 为 bf16 时,直接返回预计算的 cos_sin_cache_bf16(需设备匹配),避免执行 fallback 的 .to() 调用。
  3. 保持回退路径:不满足 AITER 编译快路径的条件时,逻辑不变。
文件 模块 状态 重要度
vllm/model_executor/layers/rotary_embedding/base.py 模型执行层 modified 6.53

关键符号

RotaryEmbeddingBase.__init__ RotaryEmbeddingBase._match_cos_sin_cache_dtype

关键源码片段

vllm/model_executor/layers/rotary_embedding/base.py data-contract

唯一修改的文件,在 RotaryEmbeddingBase 的 __init__ 和 _match_cos_sin_cache_dtype 中添加了预计算 bf16 缓存和编译时快路径。

# 在 __init__ 中,主缓存初始化后添加预计算 bf16 缓存
if init_cache:
    cache = self._compute_cos_sin_cache()
    if not self.use_flashinfer:
        cache = cache.to(dtype)
    self.register_buffer("cos_sin_cache", cache, persistent=False)
​
    # 为 AITER 编译路径预计算 bf16 缓存,避免逐层重复 cast 节点
    if self.use_aiter and cache.dtype != torch.bfloat16:
        self.register_buffer(
            "cos_sin_cache_bf16",
            cache.to(torch.bfloat16),
            persistent=False,
        )
    else:
        # 明确置为 None,确保 _match_cos_sin_cache_dtype 中 getattr 能正确判断
        self.cos_sin_cache_bf16 = None# 在 _match_cos_sin_cache_dtype 中添加快路径
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> torch.Tensor:
    cos_sin_cache = self.cos_sin_cache
    # 如果设备且 dtype 已匹配,直接返回(原逻辑)
    if (
        cos_sin_cache.device == query.device
        and self.cos_sin_cache.dtype == query.dtype
    ):
        return cos_sin_cache
​
    # AITER 编译快路径:查询为 bf16 且预计算缓冲可用时直接复用
    if (
        self.use_aiter
        and torch.compiler.is_compiling()
        and query.dtype == torch.bfloat16
    ):
        cache_bf16 = getattr(self, "cos_sin_cache_bf16", None)
        if cache_bf16 is not None and cache_bf16.device == query.device:
            return cache_bf16
​
    # 回退路径:执行 device/dtype 转换
    cos_sin_cache = cos_sin_cache.to(query.device, dtype=query.dtype)
    # 编译时避免修改 buffer
    if torch.compiler.is_compiling():
        return cos_sin_cache
    self.cos_sin_cache = cos_sin_cache
    return cos_sin_cache

评论区精华

额外 buffer 内存开销 设计

tjtanaa 询问为何不直接转换主缓存 dtype 而是另注册一个 bf16 缓存,从而增加内存占用。

结论:akii96 解释:GPT-OSS 硬编码 dtype=float32,直接改 dtype 会改变模型精度;编译时 buffer mutation 被 cudagraph 阻止导致逐层冗余 cast;额外 buffer 仅约 4MB,且仅限于 AITER 编译路径,不影响其他行为。 · 已解决

风险与影响

风险较低:仅影响 ROCm AITER 编译路径,且通过条件守卫(use_aiter、is_compiling、query.dtype == bf16)隔离。内存增加约 4MB(取决于 max_position_embeddings 和 rotary_dim),在可接受范围内。非 ROCm 或非编译路径行为完全不变。

影响范围:仅在使用 ROCm AITER、torch.compile 且 query 为 bf16 的 GPT-OSS 模型场景下生效。吞吐提升 1.7%-1.9%,TPOT 降低 1.6%-3.1%。LM-eval GSM8K 精度指标保持不变。

内存增加约 4MB

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论