Prhub

#39627 [XPU] enable triton attention test on XPU by removing cuda device binding

原始 PR 作者 yma11 合并时间 2026-04-20 20:57 文件变更 3 提交数 2 评论 3 代码增减 +52 / -36

执行摘要

移除 Triton 注意力测试中的 CUDA 硬编码,支持 XPU 等异构硬件平台。

根据 PR 描述,Triton 注意力测试无法在 XPU 上运行,因为设备被硬编码绑定到 CUDA。PR 的目的是通过使用 current_platform.device_type 移除这种绑定,使测试能够在 XPU 等平台上执行。

该 PR 展示了如何将硬编码的设备依赖重构为平台无关的测试模式,值得测试开发人员参考。虽然变更简单,但其中关于 torch.set_default_device 可能引起测试污染的讨论具有普遍警示意义。建议关注后续是否会有 PR 采纳 reviewer 的建议改用上下文管理器。

讨论亮点

reviewer gemini-code-assist[bot] 指出使用 torch.set_default_device 会修改 PyTorch 进程的全局状态,可能导致测试污染(后续测试可能期望默认的 'cpu' 设备)。建议改用 with torch.device(DEVICE_TYPE): 上下文管理器以确保测试后恢复默认设备。但 PR 最终被批准合并,未采纳此建议,表明团队可能认为当前风险可控或后续会单独处理。

实现拆解

  1. 统一设备类型获取:在每个测试文件顶部导入 current_platform 并定义 DEVICE_TYPE = current_platform.device_type,作为后续所有设备分配的统一来源。
  2. 替换张量设备分配:将测试函数中所有创建张量时硬编码的 device="cuda" 参数替换为 device=DEVICE_TYPE,包括随机张量、零张量、索引张量等。
  3. 替换默认设备设置:在 test_triton_unified_attention.py 中,将 torch.set_default_device("cuda") 替换为 torch.set_default_device(DEVICE_TYPE),确保后续张量创建默认使用正确设备。
  4. 测试配套:本次变更仅涉及测试文件,未修改任何生产代码、配置或部署脚本,属于纯测试覆盖调整。
文件 模块 状态 重要度
tests/kernels/attention/test_triton_decode_attention.py 注意力测试 modified 5.4
tests/kernels/attention/test_triton_prefill_attention.py 注意力测试 modified 4.76
tests/kernels/attention/test_triton_unified_attention.py 注意力测试 modified 4.11

关键符号

test_decode_attention test_decode_attention_fp8 test_context_attention test_context_attention_sliding_window test_triton_unified_attn test_triton_unified_attn_fp16_input_fp8_output

关键源码片段

tests/kernels/attention/test_triton_decode_attention.py test-coverage

修改了 decode 注意力测试,将硬编码的 CUDA 设备替换为动态 DEVICE_TYPE,是本次 PR 中变更量最大的文件。

import pytest
import torch
from vllm.platforms import current_platform # 新增导入,用于获取当前平台设备类型
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwdDEVICE_TYPE = current_platform.device_type # 定义全局设备类型常量,替代硬编码的 'cuda'@pytest.mark.parametrize("B", [3, 5])
# ... 其他参数化装饰器def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8
​
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    # 所有张量创建都使用 DEVICE_TYPE 而非硬编码的 "cuda"
    req_to_page = torch.randint(
        0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device=DEVICE_TYPE
    )
    req_to_token = req_to_page * PAGE_SIZE
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device=DEVICE_TYPE).view(1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()
​
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device=DEVICE_TYPE)
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device=DEVICE_TYPE)
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device=DEVICE_TYPE)
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=DEVICE_TYPE)
    lse = torch.zeros(B, H_Q, dtype=dtype, device=DEVICE_TYPE)
    b_seq_len = torch.full((B,), seq_len, device=DEVICE_TYPE)
    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device=DEVICE_TYPE,
    )
    # 调用内核函数,其内部应能处理不同设备上的张量
    decode_attention_fwd(
        q, k_buffer, v_buffer, o, lse, req_to_token, b_seq_len,
        attn_logits, num_kv_splits, sm_scale
    )
    # ... 后续断言逻辑保持不变

评论区精华

torch.set_default_device 可能导致测试污染 测试

reviewer gemini-code-assist[bot] 指出在测试函数中使用 torch.set_default_device 会修改全局状态,可能导致后续测试因默认设备意外改变而失败。建议改用 with torch.device(DEVICE_TYPE): 上下文管理器。

结论:PR 被批准合并,未采纳改用上下文管理器的建议。 · 已解决

风险与影响

测试污染风险:在 test_triton_unified_attention.py 中使用 torch.set_default_device 可能影响同一进程中后续测试的默认设备设置,导致非确定性失败。
平台兼容性风险:依赖 current_platform.device_type 假设该接口在所有目标平台上都能正确返回可用的设备类型字符串(如 'xpu', 'cuda'),若平台支持不完整可能导致测试失败。
回归风险:变更仅涉及测试设备分配,不影响核心注意力内核逻辑,因此功能回归风险极低。

对用户:无直接影响,属于内部测试改进。
对系统:扩展了 Triton 注意力内核在非 CUDA 硬件(如 Intel XPU)上的测试能力,有助于发现和预防跨平台兼容性问题。
对团队:提升了 CI/CD 流水线在异构硬件环境下的测试覆盖率,为未来支持更多硬件平台奠定基础。

测试污染风险 平台兼容性依赖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论