Prhub

#20234 Use device-agnostic helpers for Mamba tests and core ops

原始 PR 作者 roopaksrivastav 合并时间 2026-05-01 07:14 文件变更 8 提交数 15 评论 11 代码增减 +47 / -39

执行摘要

Mamba 测试与内核设备无关化以支持 XPU

Replace hardcoded CUDA device checks with device-agnostic helpers (get_device, get_device_count, torch.get_device_module) to support both CUDA and XPU devices.

值得关注 get_deviceget_device_count 等辅助函数的设计模式,可作为后续设备无关化重构的参考。

讨论亮点

PR 经历了两次 CI 检查,由 mingfeima 批准后合并。没有收到实质性 review 评论或设计争论。

实现拆解

  1. 在三个测试文件中,将 torch.cuda.is_available() 替换为 get_device() + 允许设备列表 ["cuda", "xpu"] 检查;将 torch.cuda.device_count() 替换为 get_device_count()
  2. 将测试函数和设备生成函数的默认参数从 device="cuda" 改为 device=None,并在函数内调用 get_device() 解析实际设备。
  3. 在多 GPU 测试中,将 torch.cuda.set_device() 替换为 torch.get_device_module(device).set_device()
  4. 在五个内核操作文件(layernorm_gated.pymamba_ssm.pyssd_bmm.pyssd_chunk_state.pyssd_state_passing.py)中,将 torch.cuda.device() 上下文管理器替换为 torch.get_device_module(device).device()
  5. 在测试文件中新增导入 from sglang.srt.utils import get_device, get_device_count
文件 模块 状态 重要度
test/registered/layers/mamba/test_mamba_ssm_ssd.py SSD 测试 modified 5.8
test/registered/layers/mamba/test_mamba_ssm.py SSM 测试 modified 5.16
test/registered/layers/mamba/test_mamba2_mixer.py Mixer 测试 modified 4.95
python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py SSD 块状态 modified 3.13
python/sglang/srt/layers/attention/mamba/ops/layernorm_gated.py 层归一化 modified 2.64
python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py Mamba SSM modified 2.64
python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py SSD 批乘 modified 2.64
python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py SSD 状态传递 modified 2.64

关键符号

generate_random_inputs test_selective_state_update test_selective_state_update_with_batch_indices test_mixer2_gated_norm_multi_gpu

关键源码片段

test/registered/layers/mamba/test_mamba_ssm_ssd.py test-coverage

测试文件改动最大,替换了所有设备检查和默认参数,并修改了 generate_random_inputs 函数签名

from sglang.srt.utils import get_devicedef generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device=None):
    """生成随机输入张量,device 为 None 时由 get_device() 自动选择。"""
    if device is None:
        device = get_device() # 自动适配当前设备(CUDA 或 XPU)
    # 仅允许在 CUDA 或 XPU 上运行,否则跳过测试
    if device not in ["cuda", "xpu"]:
        pytest.skip("Test only supports CUDA and XPU devices")
    torch.manual_seed(0)
    A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device))
    dt = F.softplus(torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4)
    X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
    B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
    C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
    return A, dt, X, B, Cdef generate_continuous_batched_examples(..., device=None, ...):
    A, dt, X, B, C = generate_random_inputs(
        num_examples, full_length, n_heads, d_head, itype, device
    )
    device = X.device # 从生成的张量捕获实际设备,确保后续一致

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险极低:所有替换都是机械的 API 映射,不改变原有计算逻辑。潜在风险包括:如果 get_device() 在特定环境返回意外值,可能导致测试在原本可用的设备上被跳过;XPU 支持仅通过允许列表实现,未覆盖所有可能后端,但已有 CI 覆盖 CUDA 和 AMD。

对现有 CUDA 用户无影响;XPU 用户可以运行 Mamba 测试和依赖这些内核的模型;整体提升了代码可移植性和多后端支持能力。影响范围限于 Mamba 相关测试和内核模块。

缺少 XPU 实际测试覆盖 get_device 行为依赖环境变量

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论