Prhub

#41113 [Bugfix] Fix rope

原始 PR 作者 jeejeelee 合并时间 2026-04-29 13:42 文件变更 2 提交数 1 评论 2 代码增减 +45 / -36

执行摘要

修复 ROPE 内核中 cos/sin cache 类型硬编码为 float32 的问题

修复CI中Llama-4-Scout-FP8 TP2 fusion_e2e OOM问题(Issue #41017)。原本RoPE内核强制将cos/sin cache转换为float32,这在不必要的情况下增加了显存占用。通过支持与模型权重相同类型(如bfloat16)的cache,可减少显存消耗,尤其在TP2场景下显著降低OOM概率。

建议优先审核并合并此PR,因为它修复了实际的CI OOM问题,且实现经过充分考量(限制模板组合)。开发者可关注csrc/pos_encoding_kernels.cu中模板派发模式,未来在其他kernel中可复用此方法。

讨论亮点
  • 模板实例化膨胀风险:gemini-code-assist[bot] 指出,直接对cache类型进行笛卡尔积式派发会显著增加内核实例数量(3x3x2=18种组合),导致编译时间和二进制体积增大。建议限制cache类型仅为float32或与query类型相同。
  • review决策:最终实现采用了更受限的派发策略(AT_DISPATCH_SWITCH限制cache类型为float32或query类型),避免了完全笛卡尔积,同时在性能和灵活性间取得平衡。zyongye 审核通过(LGTM)。

实现拆解

  1. CUDA内核模板化(csrc/pos_encoding_kernels.cu):将 apply_token_rotary_embeddingapply_rotary_embedding 函数模板从固定的 float* cache_ptr 改为 cache_t* cache_ptr,新增模板参数 cache_t。在 apply_token_rotary_embedding 中,通过 static_cast<float> 将cache值转换为float用于内部计算,确保精度可控。
  2. 派发逻辑重构(csrc/pos_encoding_kernels.cu):在 rotary_embedding 入口函数中,移除显式的 cos_sin_cache.to(torch::kFloat32) 转换,改为使用 VLLM_DISPATCH_FLOATING_TYPESAT_DISPATCH_SWITCH 对cache类型进行二次派发。为避免模板组合爆炸,仅允许cache类型与query类型相同或为float32(通过 AT_DISPATCH_SWITCH 限制)。
  3. 测试增强(tests/kernels/core/test_rotary_embedding.py):为 test_rotary_embedding_opcheck 增加 dtype 参数化(torch.float32torch.bfloat16),确保两种精度下的正确性。RotaryEmbedding 实例化和query/key tensor的dtype从硬编码 torch.float32 改为参数化dtype。
文件 模块 状态 重要度
csrc/pos_encoding_kernels.cu 内核 modified 5.71
tests/kernels/core/test_rotary_embedding.py 测试 modified 4.64

关键符号

apply_token_rotary_embedding apply_rotary_embedding rotary_embedding test_rotary_embedding_opcheck

关键源码片段

csrc/pos_encoding_kernels.cu core-logic

核心更改文件:将 RoPE 内核中的 cache 类型从硬编码 float32 改为模板化,降低显存占用,修复 OOM 问题。

// csrc/pos_encoding_kernels.cu 关键片段
// 将 cache 类型从固定 float 改为模板参数 cache_t,
// 避免不必要的 float32 转换,降低显存占用。template <typename scalar_t, typename cache_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr,
    const cache_t* __restrict__ cos_ptr,
    const cache_t* __restrict__ sin_ptr,
    int rot_offset, int embed_dim,
    const bool inverse) {
  // ... 内部计算仍使用 float 保证精度
  float cos_f = static_cast<float>(VLLM_LDG(cos_ptr + x_index));
  float sin_f = static_cast<float>(VLLM_LDG(sin_ptr + x_index));
  // ...
}
tests/kernels/core/test_rotary_embedding.py test-coverage

测试增强:增加 bfloat16 参数化覆盖,确保低精度场景正确性。

# tests/kernels/core/test_rotary_embedding.py 关键片段
# 新增 dtype 参数化,验证 float32 和 bfloat16 两种精度
@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.bfloat16]
)
def test_rotary_embedding_opcheck(
    ...
    dtype,
):
    # 使用参数化 dtype 初始化 RotaryEmbedding 和 tensor
    rot = RotaryEmbedding(
        head_size, rotary_dim, max_position, base, is_neox_style, dtype
    )
    query = torch.randn(
        batch_size, seq_len, num_heads, head_stride,
        dtype=dtype, device=device
    )
    # ...

评论区精华

模板实例化膨胀风险 设计

gemini-code-assist[bot] 指出嵌套派发会导致 18 种模板组合,增加编译时间和二进制大小。建议限制 cache 类型范围。

结论:实现采用了 AT_DISPATCH_SWITCH 限制 cache 类型仅可为 float32 或与 query 类型相同,避免了完全笛卡尔积。 · 已解决

风险与影响

  • 回归风险(低):变更涉及CUDA内核模板化,若cache_t推导错误可能导致编译失败或运行时错误。但测试覆盖了float32和bfloat16两种类型,且仅允许有限组合,风险可控。
  • 编译时间/二进制体积(中):虽然限制了派发组合,但仍比原单一float32版本增加了一些实例化,可能略微增加编译时间。但相比完全笛卡尔积,影响已降至可接受范围。
  • 功能兼容性(低):移除了显式的to(kFloat32),若调用方传入非预期类型的cache,可能因类型不匹配崩溃。但VLLM内部cache类型与模型权重类型一致,实际使用中不会出现问题。
  • 影响范围:仅限于RoPE CUDA内核及对应测试。所有使用RoPE的模型(含Llama系列、DeepSeek等)均会受益。
  • 用户影响:对于使用低精度(bfloat16/float16)的模型,RoPE计算可直接复用已有cache类型,减少显存占用和精度转换开销,可能降低OOM概率。
  • 系统影响:编译后的二进制体积略有增加,但运行时性能无负面影响。
核心路径变更 编译时间增加

关联 Issue

#41017 [CI Failure]: Llama-4-Scout-FP8 tp2 fusion_e2e OOM on H100

完整报告

参与讨论