Prhub

#26517 Add attention-backend unit-test suite under test/registered/attention/unittest

原始 PR 作者 ch-wan 合并时间 2026-05-29 08:30 文件变更 69 提交数 2 评论 23 代码增减 +27422 / -0

执行摘要

为注意力后端添加模块级单元测试套件

现有 test/registered/attention/ 下的测试为服务器级端到端测试,后端回归时失败晚且难以隔离。本 PR 添加模块级正确性套件,在孤立环境中测试每个支持组合,失败时直接指向具体的后端方法和输入形状。

建议所有关注注意力后端的开发者仔细阅读本 PR 的测试架构,特别是参考实现的设计和 SWA decode 规则分类。后续新增后端时,务必在 dense_attention.py 中注册 SWA decode 规则,并按照已有模式添加测试文件。KNOWN_FAILURES.md 也是必读文档,用于理解当前后端限制。

讨论亮点

讨论 1:unittest 目录命名冲突

michaelzhang-ai 指出 test/registered/attention/unittest/__init__.py 使该目录成为普通包,当注册测试运行器执行每个测试时,该目录位于 sys.path[0],导致标准库 unittest 被阴影,引发 import unittest.mock 等错误。作者确认已在 PR #26654 中解决。

讨论 2:CI 注册 review

chatgpt-codex-connector[bot] 建议添加 register_cuda_ci 注册以避免 CI 发现 abort。作者在 PR body 中说明已通过 base-b-test-4-gpu-b200 和 base-b-test-1-gpu-large 两个 suite 自动注册,无需额外改动。

实现拆解

  1. 在 python/sglang/test/kits/ 下建立共享测试工具库,包括 mock_server_args.py、per-method fixture 和 runner modes。
  2. 在 test/registered/attention/unittest/ 下按 attention method 组织测试文件,每个测试文件对应一个后端,通过构建独立的 PyTorch 参考模块进行比较。
  3. 实现覆盖维度包括:attention methods(密集、SWA、MLA、DSA、DSV4、GDN、KDA、Lightning、Mamba2 等)、backends(FA3、FA4、FlashInfer、Triton、TorchNative 等)、forward modes(DECODE、EXTEND、MIXED、TARGET_VERIFY、DRAFT_EXTEND 等)、runner modes(eager、CUDA-graph、split-op、EAGLE 等)、speculative kinds(EAGLE chain/tree、FrozenKV MTP、DFlash、NGRAM)、input shapes(页边界、前缀精确页、不规则 batch 等)、cache layouts(contiguous、shuffled_pages、interleaved_pages、non_monotonic_extend)。
  4. 添加 KNOWN_FAILURES.md 文档记录每个已知后端限制及其需要采取的行动,以及 per-method README.md。
  5. 在 CI 中通过 base-b-test-4-gpu-b200 和 base-b-test-1-gpu-large 两个 suite 注册测试,无需更改 workflow 文件。
文件 模块 状态 重要度
python/sglang/test/kits/attention_unittest/attention_methods/dense_attention.py 密集注意力 added 7.76
python/sglang/test/kits/attention_unittest/attention_methods/dsa_attention.py DSA 注意力 added 7.48
python/sglang/test/kits/attention_unittest/attention_methods/dsv4_attention.py DSV4 注意力 added 7.48
python/sglang/test/kits/attention_unittest/runner_modes/speculative_draft_extend_runner.py 推测解码 added 7.48

关键符号

_swa_decode_uses_min_seq_len_rule make_dense_input_config_cases make_dense_attention_config_cases make_dsa_dense_fallback_cases make_dsa_sparse_cases make_dsv4_cases _make_eagle_draft_extend_input _make_frozen_kv_mtp_draft_extend_input _make_draft_extend_input _make_eagle_draft_extend_v2_input _prepare_draft_extend_batch _prepare_eagle_draft_extend_batch

关键源码片段

python/sglang/test/kits/attention_unittest/attention_methods/dense_attention.py test-coverage

提供基础注意力案例 dataclass 定义和 SWA decode 规则分类,是所有密集注意力测试的核心基础。

# 分类 SWA decode 使用的元数据规则,因生产后端不同而异。
# - min_seq_len_window: window_kv_lens = min(seq_lens, window)(不含当前 token)
# - extend_window: 允许 [query_pos - window, query_pos] 的 key(含当前 token)
_SWA_DECODE_MIN_SEQ_LEN_WINDOW: frozenset[str] = frozenset({"triton"})
_SWA_DECODE_EXTEND_WINDOW: frozenset[str] = frozenset(
    {"torch_native", "fa3", "fa4", "flex_attention", "trtllm_mha", "flashinfer"}
)def _swa_decode_uses_min_seq_len_rule(case: "DenseAttentionCase") -> bool:
    if case.backend in _SWA_DECODE_MIN_SEQ_LEN_WINDOW:
        return True
    if case.backend in _SWA_DECODE_EXTEND_WINDOW:
        return False
    raise ValueError(
        f"未知的 SWA decode 规则后端 {case.backend!r}。请根据其 "
        f"`init_forward_metadata_decode` 生成的元数据,将其添加到 "
        f"`_SWA_DECODE_MIN_SEQ_LEN_WINDOW` 或 `_SWA_DECODE_EXTEND_WINDOW`。"
    )@dataclass(frozen=True)
class DenseAttentionCase:
    name: str
    backend: str
    forward_mode: ForwardMode
    num_heads: int
    num_kv_heads: int
    page_size: int
    prefix_lens: tuple[int, ...]
    extend_lens: tuple[int, ...] = ()
    sliding_window_size: int | None = None
​
    @property
    def batch_size(self) -> int:
        return len(self.prefix_lens)
​
    @property
    def input_lens(self) -> tuple[int, ...]:
        if self.forward_mode.is_decode():
            return (1,) * self.batch_size
        return self.extend_lens
​
    @property
    def seq_lens(self) -> tuple[int, ...]:
        return tuple(p + q for p, q in zip(self.prefix_lens, self.input_lens))
​
    @property
    def num_input_tokens(self) -> int:
        return sum(self.input_lens)
python/sglang/test/kits/attention_unittest/attention_methods/dsa_attention.py test-coverage

定义 DSA 注意力密集回退和稀疏测试案例,覆盖页边界条件,是 DSA 后端测试的基础。

@dataclass(frozen=True)
class DSAAttentionCase(DenseAttentionCase):
    passDSA_PAGE_SIZE = 64
DSA_INDEX_TOPK = 8
DSA_SPARSE_ATOL = 1.6e-1
DSA_SPARSE_RTOL = 1.6e-1def make_dsa_dense_fallback_cases(backend: str) -> tuple[DSAAttentionCase, ...]:
    common = dict(
        backend=backend,
        forward_mode=ForwardMode.EXTEND,
        num_heads=4,
        num_kv_heads=4,
        page_size=DSA_PAGE_SIZE,
    )
    return (
        DSAAttentionCase(
            name="dsa_mha_one_shot_no_prefix_ragged",
            prefix_lens=(0, 0, 0),
            extend_lens=(3, 8, 17),
            **common,
        ),
        DSAAttentionCase(
            name="dsa_mha_one_shot_no_prefix_exact_page",
            prefix_lens=(0,),
            extend_lens=(DSA_PAGE_SIZE,),
            **common,
        ),
        # 零前缀 extend,序列长度刚好比页大小少 1,与精确页和跨页案例组合覆盖 < 页、= 页、> 页 三种情况
        DSAAttentionCase(
            name="dsa_mha_one_shot_no_prefix_seq_below_page",
            prefix_lens=(0,),
            extend_lens=(DSA_PAGE_SIZE - 1,),
            **common,
        ),
        # 混合 batch:低于、等于、高于页边界,测试密集回退路径的 K 写入不跨请求页表
        DSAAttentionCase(
            name="dsa_mha_one_shot_ragged_below_at_above_page",
            prefix_lens=(0, 0, 0),
            extend_lens=(DSA_PAGE_SIZE - 1, DSA_PAGE_SIZE, DSA_PAGE_SIZE + 1),
            **common,
        ),
    )

评论区精华

unittest 目录命名导致标准库阴影 正确性

michaelzhang-ai 指出 test/registered/attention/unittest/__init__.py 使该目录成为普通包,当注册测试运行器执行每个测试时,该目录位于 sys.path[0],导致标准库 unittest 被阴影,引发 import unittest.mock 错误。

结论:作者在 PR #26654 中通过调整解决了此问题。 · 已解决

测试文件 CI 注册缺失提醒 测试

chatgpt-codex-connector[bot] 审查时建议添加 register_cuda_ci 注册,否则 CI 发现会 abort。

结论:作者在 PR body 中说明已通过 base-b-test-4-gpu-b200 和 base-b-test-1-gpu-large 两个 suite 自动注册,无需额外改动。 · 已解决

风险与影响

  1. 命名冲突风险:已将目录命名为 unittest,与标准库冲突,已在 #26654 修复,但类似的命名问题仍需警惕。
  2. CI 时间增加:27 个测试文件预计约 595 秒,可能延长 CI 流水线总时长。
  3. 硬件依赖:部分测试需要 Blackwell GPU(SM>=100),在其他硬件上自动跳过,但可能降低非 Blackwell 环境的覆盖率。
  4. 参考实现漂移:随着后端变化,PyTorch 参考可能与实际后端行为偏离,需要维护。

对开发者:当引入新的注意力后端或修改现有后端时,可以立即运行对应的单元测试来验证正确性,大幅缩短回归定位时间。
对系统:减少生产环境出现注意力计算错误的风险。
对团队:建立了统一的测试框架,新增后端时只需按照 fixture 模式添加测试文件和案例即可。

命名冲突 CI 时间增加 硬件依赖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论