执行摘要
- 一句话:为三个测试文件添加XPU设备支持
- 推荐动作:此PR展示了在SGLang中为测试添加新硬件支持的标准化方法:使用
get_device()替代硬编码设备字符串,并相应调整跳过条件。虽然改动量小,但可以作为今后测试跨硬件适配的模板。建议阅读test_triton_scaled_mm.py的完整实现,以及review评论中关于安全调用torch.xpu的讨论,以避免类似问题。总体而言,值得快速浏览,但不需要深入精读。
功能与动机
为了支持Intel XPU设备,需要使单元测试能够无差别地在CUDA和XPU上执行。PR body中提到'Replace hardcoded 'cuda' device references with get_device() utility to enable tests to run on both CUDA and XPU devices.'
实现拆解
实现步骤如下:
- 导入工具函数:在每个测试文件中添加
from sglang.srt.utils.common import get_device。
- 设备变量化:在类的
setUp或setUpClass中通过get_device()获取当前可用设备,并保存为实例变量或类变量。原来硬编码的device="cuda"全部替换为device=self.device或device=cls._device。
- 更新跳过条件:将
@unittest.skipIf(not torch.cuda.is_available(), ...)改为@unittest.skipIf(not (torch.cuda.is_available() or torch.xpu.is_available()), ...),使测试在CUDA或XPU可用时均会执行(或跳过)。
- 调整默认设备:在
TestScaledMM.setUpClass中同步修改torch.set_default_device为目标设备。
- 注意事项:
test_kda_kernels.py中第二个测试类TestKDAGateChunkCumsum仍使用硬编码'cuda',未作适配,需后续处理。
关键文件:
test/registered/attention/test_kda_kernels.py(模块 KDA内核测试;类别 test;类型 test-coverage;符号 TestKDAFusedSigmoidGatingRecurrent, get_device): 此文件是KDA内核测试,修改了第一个测试类,但第二个类未改动,体现了不完全适配,是review讨论的焦点。
test/registered/quant/test_triton_scaled_mm.py(模块 缩放矩阵乘测试;类别 test;类型 test-coverage;符号 TestScaledMM, get_device): 完整展示了setUpClass和_make_inputs的设备抽象,是典型的适配模式。
python/sglang/test/attention/test_prefix_chunk_info.py(模块 前缀块测试;类别 test;类型 test-coverage;符号 TestPrefixChunkInfo, get_device): 修改了skip条件和device赋值,展示了最简单的适配模式。
关键符号:get_device, TestKDAFusedSigmoidGatingRecurrent.setUp, TestScaledMM.setUpClass, TestScaledMM._make_inputs, TestPrefixChunkInfo.setUp
关键源码片段
test/registered/attention/test_kda_kernels.py
此文件是KDA内核测试,修改了第一个测试类,但第二个类未改动,体现了不完全适配,是review讨论的焦点。
import unittest
import torch
from sglang.srt.utils.common import get_device
# 注意:此文件中第二个测试类 TestKDAGateChunkCumsum 仍完全硬编码 'cuda',本次未做任何修改。
# 对于 TestKDAFusedSigmoidGatingRecurrent,以下展示了设备适配的主要改动。
@unittest.skipIf(
not (torch.cuda.is_available() or torch.xpu.is_available()),
"Test requires CUDA or XPU",
)
class TestKDAFusedSigmoidGatingRecurrent(unittest.TestCase):
def setUp(self):
self.device = get_device() # 动态获取当前可用设备
self.token_num = 4
# 原来的 device="cuda" 都替换为 device=self.device
self.query_start_loc = torch.tensor([0, 1, 2, 3, 4], device=self.device)
self.cache_indices = torch.tensor([0, 2, 5, 8], device=self.device)
self.local_num_heads = 8
self.head_dim = 128
self.cache_len = 64
self.A_log = torch.randn(1, 1, self.local_num_heads, 1, dtype=torch.float32, device=self.device)
# 其他张量类似 ...
self.ssm_states = torch.zeros(self.cache_len, self.local_num_heads, self.head_dim, self.head_dim, dtype=torch.float32, device=self.device)
def run_fused(self):
# 方法体未改动,仅设备已由 self.device 确定
pass
def run_kda(self):
# 方法体未改动
pass
test/registered/quant/test_triton_scaled_mm.py
完整展示了setUpClass和_make_inputs的设备抽象,是典型的适配模式。
import unittest
from typing import Optional
import torch
import torch.testing
from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm
from sglang.srt.utils.common import get_device
from sglang.test.test_utils import CustomTestCase
class TestScaledMM(CustomTestCase):
@classmethod
def setUpClass(cls):
# 同时检查 CUDA 和 XPU 可用性,避免直接 torch.cuda.is_available() 抛异常
if not (torch.cuda.is_available() or torch.xpu.is_available()):
raise unittest.SkipTest("No CUDA or XPU device available")
cls._device = get_device() # 获取当前可用设备
torch.set_default_device(cls._device) # 设置默认设备为获取的设备
def _make_inputs(self, M, K, N, in_dtype):
# 原来 device='cuda' 全部替换为 device=self._device
if in_dtype == torch.int8:
a = torch.randint(-8, 8, (M, K), dtype=in_dtype, device=self._device)
b = torch.randint(-8, 8, (K, N), dtype=in_dtype, device=self._device)
else: # fp8
a = torch.clamp(
0.1 * torch.randn((M, K), dtype=torch.float16, device=self._device),
-0.3, 0.3,
).to(in_dtype)
b = torch.clamp(
0.1 * torch.randn((K, N), dtype=torch.float16, device=self._device),
-0.3, 0.3,
).to(in_dtype)
return a, b
def test_basic_cases(self):
# 测试逻辑不变,内部张量已通过 self._device 创建
pass
评论区精华
Review评论主要来自Copilot,提出了两个问题:
torch.xpu.is_available()的安全调用:Copilot指出直接使用torch.xpu.is_available()在没有定义torch.xpu的PyTorch环境中可能导致AttributeError,建议用hasattr(torch, "xpu")或调用已有工具函数is_xpu()保护。此评论未得到作者或合并者回应,但PR仍被合并。
test_kda_kernels.py的不完全适配:Copilot发现文件中第二个测试类TestKDAGateChunkCumsum仍全部硬编码device="cuda",本次修改仅覆盖了第一个类。作者未回应或修复,PR描述可能引起歧义。
- torch.xpu.is_available安全调用 (correctness): 未采纳,PR合并时未修改。
- TestKDAGateChunkCumsum未适配 (testing): 未处理,PR描述可能不准确。
风险与影响
-
风险:技术风险包括:
-
兼容性风险:三个文件的skipIf条件和setUpClass中直接使用torch.xpu.is_available(),在torch.xpu未定义的环境中会抛出AttributeError。目前CI环境可能已具备该属性,但未来扩展时需警惕。建议统一使用is_xpu()工具函数。
- 测试覆盖遗漏:
test_kda_kernels.py中TestKDAGateChunkCumsum未适配,若在XPU CI中运行该文件,该测试类仍会因CUDA不可用而跳过(因为skipIf未更新),不会产生错误,但会导致测试覆盖盲区。
- 回归风险低:对于纯CUDA环境,
get_device()返回'cuda',行为与原来一致,不会引入回归。
-
影响:影响范围限于三个测试文件,对用户无直接影响。对开发和CI团队:
-
正面影响:这三个测试可以在XPU CI流水线中执行或跳过,扩大测试覆盖。
- 负面影响:如果XPU环境配置不当,可能因
AttributeError导致测试失败。但鉴于PR已合并,预计内部已验证。
- 团队影响:为后续更多测试的XPU适配提供了可参考的模式。
- 风险标记:安全调用torch.xpu, 部分测试未完全适配
关联脉络
- PR #23557 [Intel GPU] Integrate flash_mla_decode in Intel XPU attention backend: 同为Intel XPU支持,提供了注意力后端的XPU集成,本PR为相关测试提供设备适配。
参与讨论