Prhub

#20672 [MUSA][17/N] ci: Add MUSA diffusion, sgl-kernel tests, and CI workflow support

原始 PR 作者 johnnycxm 合并时间 2026-05-08 11:45 文件变更 14 提交数 2 评论 6 代码增减 +1549 / -29

执行摘要

为 MUSA 后端添加 CI 测试工作流与单元测试套件

该 PR 是 #16565 中描述的为 MUSA 添加完整支持的一部分,主要目标是为 MUSA 后端建立 CI 覆盖,确保多模态扩散和 sgl-kernel 功能在 MUSA 上持续验证,及早捕获回归,同时保持代码库在 CUDA、ROCm 和 MUSA 之间的统一。

该 PR 是新增硬件后端 CI 基础设施的重要贡献,值得所有关心多后端支持的开发者阅读。特别值得关注的是 run_suite_musa.py 的分区执行与失败重试设计,以及基准测试中条件导入 MUSA 核函数以兼容多后端的模式。

讨论亮点
  • yeahdongcn 指出 CI 工作流构建 sgl-kernel 依赖于尚未合并的 #17946,需要在合并后 rebase 再测试。
  • gemini-code-assist[bot] 建议在 test_musa_rmsnorm.pytest_musa_silu_and_mul.py 中使用 @pytest.mark.parametrize 替代 for 循环遍历数据类型,以提升测试可调试性。
  • 作者采纳建议后,yeahdongcn 审核并批准。

实现拆解

  1. 创建 MUSA CI 工作流:在 .github/workflows/pr-test-musa.yml 中定义了支持 pushpull_requestworkflow_dispatchworkflow_call 的工作流,包含 target_stagerun_all_tests 输入参数,可选择性运行特定阶段或全部 MUSA 测试。工作流触发多模态生成测试(1-GPU 和 2-GPU)以及 sgl-kernel 单元测试。

  2. 添加扩散模型 MUSA 测试套件:在 python/sglang/multimodal_gen/test/run_suite_musa.py 中定义了 1-gpu-musa2-gpu-musa 测试套件,并根据分区 ID 和总分片数执行并行测试采集与运行。对应的测试文件 test_server_a_musa.pytest_server_b_musa.pytest_server_2_gpu_a_musa.py 分别覆盖单卡图和视频扩散以及双卡场景。

  3. 添加层操作单元测试:在 python/sglang/multimodal_gen/test/layers/ 下新增 test_musa_rmsnorm.pytest_musa_silu_and_mul.py,通过调用 forward_musa 并与 forward_native 对比,验证 MUSA 自定义算子的正确性。覆盖多种隐藏层大小、数据类型(FP16/BF16/FP32)、2D/3D 形状、残差连接、零输入和大数值等边缘情况。

  4. 引入 MUSA 性能基线python/sglang/multimodal_gen/test/server/musa/perf_baselines_musa.json 记录了扩散模型各阶段(文本编码、去噪、解码等)的期望耗时,用于性能回归门控。

  5. 更新 sgl-kernel 基准测试:在 sgl-kernel/benchmark/bench_moe_topk_sigmoid.pybench_moe_topk_softmax.py 中,通过条件导入 MUSA 核函数,在 calculate_diff 和基准报告中加入 MUSA 实现比较,使基准测试可在 CUDA 和 MUSA 环境下灵活运行。

同时,scripts/ci/musa/musa_install_dependency.shpython/sglang/multimodal_gen/test/run_suite.py 也进行了相应调整以注册 MUSA 套件。

文件 模块 状态 重要度
python/sglang/multimodal_gen/test/layers/test_musa_rmsnorm.py RMSNorm 测试 added 8.05
python/sglang/multimodal_gen/test/run_suite_musa.py 测试套件 added 8.05
python/sglang/multimodal_gen/test/layers/test_musa_silu_and_mul.py SiLU+Mul 测试 added 7.51
sgl-kernel/benchmark/bench_moe_topk_sigmoid.py TopK 基准 modified 7.21
sgl-kernel/benchmark/bench_moe_topk_softmax.py TopK Softmax 基准 modified 7.1
.github/workflows/pr-test-musa.yml CI 工作流 added 6.38
python/sglang/multimodal_gen/test/server/musa/perf_baselines_musa.json 性能基线 added 5.92

关键符号

get_musa_device forward_musa forward_native _make_norm run_pytest collect_test_items parse_args musa_topk_sigmoid_fn musa_topk_softmax_fn sglang_topk_sigmoid sglang_topk_softmax calculate_diff

关键源码片段

python/sglang/multimodal_gen/test/layers/test_musa_rmsnorm.py test-coverage

新增的 MUSA RMSNorm 单元测试,覆盖多种隐藏层大小、数据类型、2D/3D 输入和残差路径,是验证 MUSA 算子正确性的核心测试。

# SPDX-License-Identifier: Apache-2.0
"""
Tests for MUSA-specific RMSNorm custom op.
These tests call forward_musa directly and compare against forward_native
as the reference implementation.
"""
import pytest
import torch# 仅在 MUSA 设备可用时运行整个模块
_musa_available = hasattr(torch, "musa") and torch.musa.is_available()
pytestmark = pytest.mark.skipif(not _musa_available, reason="MUSA device not available")SEED = 42
​
​
def get_musa_device():
    return torch.device("musa:0")
​
​
class TestRMSNorm:
    """Tests for RMSNorm.forward_musa vs forward_native."""
​
    @pytest.fixture(autouse=True)
    def setup(self):
        self.device = get_musa_device()
​
    def _make_norm(self, hidden_size, eps=1e-6, var_hidden_size=None):
        from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
        norm = RMSNorm(hidden_size, eps=eps, var_hidden_size=var_hidden_size)
        norm = norm.to(self.device)
        return norm
​
    # 无残差分支测试:多种 hidden_size 和 dtype
    @pytest.mark.parametrize("hidden_size", [64, 128, 256, 512, 1024, 2048])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_no_residual_matches_native(self, hidden_size, dtype):
        """forward_musa without residual should match forward_native."""
        norm = self._make_norm(hidden_size)
        torch.manual_seed(SEED)
        x = torch.randn(4, hidden_size, dtype=dtype, device=self.device)
​
        out_musa = norm.forward_musa(x.clone())
        out_native = norm.forward_native(x.clone())
​
        # 根据 dtype 调节容许误差
        atol = 1e-2 if dtype in (torch.float16, torch.bfloat16) else 1e-4
        rtol = 1e-2 if dtype in (torch.float16, torch.bfloat16) else 1e-4
        torch.testing.assert_close(out_musa, out_native, atol=atol, rtol=rtol)
​
    # 带残差分支测试
    @pytest.mark.parametrize("hidden_size", [64, 128, 256, 512, 1024])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_with_residual_matches_native(self, hidden_size, dtype):
        """forward_musa with residual should match forward_native."""
        norm = self._make_norm(hidden_size)
        torch.manual_seed(SEED)
        x = torch.randn(4, hidden_size, dtype=dtype, device=self.device)
        residual = torch.randn(4, hidden_size, dtype=dtype, device=self.device)
​
        # 注意:forward_musa 会就地修改输入,所以先克隆
        x_musa, res_musa = x.clone(), residual.clone()
        x_native, res_native = x.clone(), residual.clone()
​
        out_musa, res_out_musa = norm.forward_musa(x_musa, res_musa)
        out_native, res_out_native = norm.forward_native(x_native, res_native)
​
        atol = 1e-2 if dtype in (torch.float16, torch.bfloat16) else 1e-4
        rtol = 1e-2 if dtype in (torch.float16, torch.bfloat16) else 1e-4
        torch.testing.assert_close(out_musa, out_native, atol=atol, rtol=rtol)
        torch.testing.assert_close(res_out_musa, res_out_native, atol=atol, rtol=rtol)
python/sglang/multimodal_gen/test/run_suite_musa.py test-coverage

MUSA 测试套件运行器,实现分区并行测试收集和失败重试逻辑,是 CI 执行 MUSA 测试的入口。

"""
Test runner for multimodal_gen MUSA suites that manages partitioned execution.
"""
import argparse
import subprocess
import sysfrom sglang.multimodal_gen.runtime.utils.logging_utils import init_loggerlogger = init_logger(__name__)# 定义 MUSA 测试套件,每套包含一组测试文件
SUITES = {
    "1-gpu-musa": [
        "musa/test_server_a_musa.py",
        "musa/test_server_b_musa.py",
    ],
    "2-gpu-musa": [
        "musa/test_server_2_gpu_a_musa.py",
    ],
}
​
​
def parse_args():
    """解析命令行参数,支持套件、分区、过滤等选项"""
    parser = argparse.ArgumentParser(description="Run multimodal_gen MUSA test suite")
    parser.add_argument("--suite", type=str, required=True, choices=list(SUITES.keys()),
                        help="The test suite to run")
    parser.add_argument("--partition-id", type=int, default=0,
                        help="Index of the current partition (for parallel execution)")
    parser.add_argument("--total-partitions", type=int, default=1,
                        help="Total number of partitions")
    parser.add_argument("-k", "--filter", type=str, default=None,
                        help="Pytest filter expression (passed to pytest -k)")
    parser.add_argument("--continue-on-error", action="store_true", default=False,
                        help="Continue running remaining tests even if one fails.")
    return parser.parse_args()
​
​
def collect_test_items(files, filter_expr=None):
    """使用 pytest --collect-only 收集测试项,返回节点 ID 列表"""
    cmd = [sys.executable, "-m", "pytest", "--collect-only", "-q"]
    if filter_expr:
        cmd.extend(["-k", filter_expr])
    cmd.extend(files)
​
    print(f"Collecting tests with command: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
​
    # exit code 5 表示无测试被收集(可能是过滤导致的),视为正常
    if result.returncode not in (0, 5):
        error_msg = (
            f"pytest --collect-only failed with exit code {result.returncode}\n"
            f"Command: {' '.join(cmd)}\n"
        )
        if result.stderr:
            error_msg += f"stderr:\n{result.stderr}\n"
        if result.stdout:
            error_msg += f"stdout:\n{result.stdout}\n"
        logger.error(error_msg)
        raise RuntimeError(error_msg)
​
    if result.returncode == 5:
        print("No tests were collected (exit code 5). This may be expected with filters.")
​
    test_items = []
    for line in result.stdout.strip().split("\n"):
        line = line.strip()
        if line and "::" in line and not line.startswith(("=", "-", " ")):
            test_id = line.split()[0] if " " in line else line
            if "::" in test_id:
                test_items.append(test_id)
​
    print(f"Collected {len(test_items)} test items")
    return test_items

评论区精华

CI 构建依赖 #17946 other

yeahdongcn 指出 CI 工作流构建 sgl-kernel 依赖于尚未合并的 #17946,需要等待该 PR 合并后 rebase 再测试。

结论:#17946 已合并,作者成功 rebase,CI 正常工作。 · 已解决

使用 pytest.mark.parametrize 以提升测试隔离性 测试

gemini-code-assist[bot] 建议在 test_musa_rmsnorm.py 和 test_musa_silu_and_mul.py 中使用 @pytest.mark.parametrize 替代 for 循环遍历 dtype,使每个数据形成一个独立测试用例,方便定位失败。

结论:作者采纳建议,已修改代码。 · 已解决

风险与影响

  1. 硬件依赖性:MUSA 测试仅在配备摩尔线程 GPU 的 CI runner 上执行,其他环境自动跳过,可能掩盖跨后端通用逻辑的回归。
  2. 性能基线漂移perf_baselines_musa.json 中的期望值基于特定硬件抓取,若 CI 池硬件变更或驱动升级,可能导致误报或漏报。
  3. 维护成本:新增 14 个文件和 CI 工作流,增加了后续维护和调试的成本,特别是非 MUSA 环境下条件导入的完整性需要持续关注。
  • 用户:MUSA 设备用户将受益于持续验证,降低新版本中 MUSA 后端的退化风险。
  • 系统:多模态生成和 sgl-kernel 在 MUSA 上的正确性和性能得到保障。
  • 团队:需要维护 MUSA 特定的 CI 流程和测试套件,但有助于在早期捕获跨后端问题。
硬件依赖性 性能基线漂移 维护成本增加

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论