Prhub

#22236 [Test] Add XPU device support to unit tests

原始 PR 作者 singhalshubham03 合并时间 2026-05-01 07:18 文件变更 3 提交数 5 评论 6 代码增减 +45 / -24

执行摘要

为三个测试文件添加 XPU 设备支持

为了支持Intel XPU设备,需要使单元测试能够无差别地在CUDA和XPU上执行。PR body中提到'Replace hardcoded 'cuda' device references with get_device() utility to enable tests to run on both CUDA and XPU devices.'

此PR展示了在SGLang中为测试添加新硬件支持的标准化方法:使用get_device()替代硬编码设备字符串,并相应调整跳过条件。虽然改动量小,但可以作为今后测试跨硬件适配的模板。建议阅读test_triton_scaled_mm.py的完整实现,以及review评论中关于安全调用torch.xpu的讨论,以避免类似问题。总体而言,值得快速浏览,但不需要深入精读。

讨论亮点

Review评论主要来自Copilot,提出了两个问题:

  1. torch.xpu.is_available()的安全调用:Copilot指出直接使用torch.xpu.is_available()在没有定义torch.xpu的PyTorch环境中可能导致AttributeError,建议用hasattr(torch, "xpu")或调用已有工具函数is_xpu()保护。此评论未得到作者或合并者回应,但PR仍被合并。
  2. test_kda_kernels.py的不完全适配:Copilot发现文件中第二个测试类TestKDAGateChunkCumsum仍全部硬编码device="cuda",本次修改仅覆盖了第一个类。作者未回应或修复,PR描述可能引起歧义。

实现拆解

实现步骤如下:

  1. 导入工具函数:在每个测试文件中添加from sglang.srt.utils.common import get_device
  2. 设备变量化:在类的setUpsetUpClass中通过get_device()获取当前可用设备,并保存为实例变量或类变量。原来硬编码的device="cuda"全部替换为device=self.devicedevice=cls._device
  3. 更新跳过条件:将@unittest.skipIf(not torch.cuda.is_available(), ...)改为@unittest.skipIf(not (torch.cuda.is_available() or torch.xpu.is_available()), ...),使测试在CUDA或XPU可用时均会执行(或跳过)。
  4. 调整默认设备:在TestScaledMM.setUpClass中同步修改torch.set_default_device为目标设备。
  5. 注意事项test_kda_kernels.py中第二个测试类TestKDAGateChunkCumsum仍使用硬编码'cuda',未作适配,需后续处理。
文件 模块 状态 重要度
test/registered/attention/test_kda_kernels.py KDA 内核测试 modified 5.23
test/registered/quant/test_triton_scaled_mm.py 缩放矩阵乘测试 modified 4.77
python/sglang/test/attention/test_prefix_chunk_info.py 前缀块测试 modified 4.37

关键符号

get_device TestKDAFusedSigmoidGatingRecurrent.setUp TestScaledMM.setUpClass TestScaledMM._make_inputs TestPrefixChunkInfo.setUp

关键源码片段

test/registered/attention/test_kda_kernels.py test-coverage

此文件是 KDA 内核测试,修改了第一个测试类,但第二个类未改动,体现了不完全适配,是 review 讨论的焦点。

import unittest
import torch
from sglang.srt.utils.common import get_device# 注意:此文件中第二个测试类 TestKDAGateChunkCumsum 仍完全硬编码 'cuda',本次未做任何修改。
# 对于 TestKDAFusedSigmoidGatingRecurrent,以下展示了设备适配的主要改动。@unittest.skipIf(
    not (torch.cuda.is_available() or torch.xpu.is_available()),
    "Test requires CUDA or XPU",
)
class TestKDAFusedSigmoidGatingRecurrent(unittest.TestCase):
    def setUp(self):
        self.device = get_device() # 动态获取当前可用设备
        self.token_num = 4
        # 原来的 device="cuda" 都替换为 device=self.device
        self.query_start_loc = torch.tensor([0, 1, 2, 3, 4], device=self.device)
        self.cache_indices = torch.tensor([0, 2, 5, 8], device=self.device)
        self.local_num_heads = 8
        self.head_dim = 128
        self.cache_len = 64
        self.A_log = torch.randn(1, 1, self.local_num_heads, 1, dtype=torch.float32, device=self.device)
        # 其他张量类似 ...
        self.ssm_states = torch.zeros(self.cache_len, self.local_num_heads, self.head_dim, self.head_dim, dtype=torch.float32, device=self.device)
​
    def run_fused(self):
        # 方法体未改动,仅设备已由 self.device 确定
        pass
​
    def run_kda(self):
        # 方法体未改动
        pass
test/registered/quant/test_triton_scaled_mm.py test-coverage

完整展示了 setUpClass 和 _make_inputs 的设备抽象,是典型的适配模式。

import unittest
from typing import Optional
import torch
import torch.testing
from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm
from sglang.srt.utils.common import get_device
from sglang.test.test_utils import CustomTestCaseclass TestScaledMM(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        # 同时检查 CUDA 和 XPU 可用性,避免直接 torch.cuda.is_available() 抛异常
        if not (torch.cuda.is_available() or torch.xpu.is_available()):
            raise unittest.SkipTest("No CUDA or XPU device available")
        cls._device = get_device() # 获取当前可用设备
        torch.set_default_device(cls._device) # 设置默认设备为获取的设备
​
    def _make_inputs(self, M, K, N, in_dtype):
        # 原来 device='cuda' 全部替换为 device=self._device
        if in_dtype == torch.int8:
            a = torch.randint(-8, 8, (M, K), dtype=in_dtype, device=self._device)
            b = torch.randint(-8, 8, (K, N), dtype=in_dtype, device=self._device)
        else: # fp8
            a = torch.clamp(
                0.1 * torch.randn((M, K), dtype=torch.float16, device=self._device),
                -0.3, 0.3,
            ).to(in_dtype)
            b = torch.clamp(
                0.1 * torch.randn((K, N), dtype=torch.float16, device=self._device),
                -0.3, 0.3,
            ).to(in_dtype)
        return a, b
​
    def test_basic_cases(self):
        # 测试逻辑不变,内部张量已通过 self._device 创建
        pass

评论区精华

torch.xpu.is_available 安全调用 正确性

Copilot 评论指出直接使用 torch.xpu.is_available() 可能导致 AttributeError,建议使用 hasattr 保护或 is_xpu()。

结论:未采纳,PR 合并时未修改。 · 待处理

TestKDAGateChunkCumsum 未适配 测试

Copilot 指出第二个测试类仍硬编码 cuda,描述与实际情况不符。

结论:未处理,PR 描述可能不准确。 · 待处理

风险与影响

技术风险包括:

  • 兼容性风险:三个文件的skipIf条件和setUpClass中直接使用torch.xpu.is_available(),在torch.xpu未定义的环境中会抛出AttributeError。目前CI环境可能已具备该属性,但未来扩展时需警惕。建议统一使用is_xpu()工具函数。
  • 测试覆盖遗漏test_kda_kernels.pyTestKDAGateChunkCumsum未适配,若在XPU CI中运行该文件,该测试类仍会因CUDA不可用而跳过(因为skipIf未更新),不会产生错误,但会导致测试覆盖盲区。
  • 回归风险低:对于纯CUDA环境,get_device()返回'cuda',行为与原来一致,不会引入回归。

影响范围限于三个测试文件,对用户无直接影响。对开发和CI团队:

  • 正面影响:这三个测试可以在XPU CI流水线中执行或跳过,扩大测试覆盖。
  • 负面影响:如果XPU环境配置不当,可能因AttributeError导致测试失败。但鉴于PR已合并,预计内部已验证。
  • 团队影响:为后续更多测试的XPU适配提供了可参考的模式。
安全调用 torch.xpu 部分测试未完全适配

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论