Prhub

#26658 test: strengthen CG-replay coverage with prod-fill padding, metadata invariants, and pad-ratio sweep

原始 PR 作者 ch-wan 合并时间 2026-05-29 13:43 文件变更 8 提交数 1 评论 5 代码增减 +323 / -25

执行摘要

强化 CG-replay 测试:生产填充、元数据不变式、多比例扫描

Codex review of #26651发现了一个真实bug(CG replay时kv_lens为负),而现有测试因填充模式不匹配未能捕获。根本原因是测试用的填充模式(padded行的seq_lens设为capture_prefix_len + N)与生产环境(padded行的seq_lens设为seq_len_fill_value、extend_seq_lens设为num_tokens_per_bs,导致减出来为负)不同。本PR旨在缩小测试与生产环境之间的差距,并增加两层防御。

值得精读。本PR展示了如何通过分析测试/生产环境差异来设计有针对性的测试覆盖。assert_cg_metadata_well_formed的设计原则(best-effort、静默跳过、单语句检查)和pad_style抽象值得在其他测试套件中复用。

讨论亮点

PR无review评论。PR作者在body中详细解释了设计决策:选择生产风格填充而非简单模拟,是因为真实生产场景中padded行元数据不一致(seq_lens和extend_seq_lens组合会导致负值),必须被后端正确处理。metadata_invariants的检查是best-effort的,不会导致错误阻断,但对已暴露的字段执行强检查。

实现拆解

  1. 新增metadata_invariants.pypython/sglang/test/kits/attention_unittest/runner_modes/metadata_invariants.py):定义 assert_cg_metadata_well_formed 函数,在CG replay元数据初始化后执行,检查forward_metadata上的indptr数组是否单调非递减且首元素为0、per-request长度是否非负。支持FA和Triton后端,缺失字段时静默跳过。
  2. 修改speculative_cuda_graph_runner.py:为SpeculativeCudaGraphAdapter类新增pad_style字段(Literal["small_real", "prod_fill"])和pad_num_tokens_per_bs字段;实现_apply_prod_fill_padding函数,按生产版本的方式覆写padded行的seq_lensextend_seq_lensreq_pool_indicesout_cache_locpositionsinput_ids等,使padded行的seq_lens - extend_seq_lens为负,考验backends的鲁棒性。
  3. 修改各测试文件(test_fa3.pytest_fa4.pytest_triton.pytest_mla_triton.py)中的DRAFT_EXTEND_V2测试方法:从单个subTest扩展为pad_style ∈ {small_real, prod_fill} × capture_bs ∈ {1×, 2×, 4×}的6个subTest,覆盖0%、50%、75%填充比例。
  4. 配套调整:在cuda_graph_decode_runner.py中导入并调用assert_cg_metadata_well_formed;在speculative_draft_extend_runner.py中添加pad_style参数透传。
文件 模块 状态 重要度
python/sglang/test/kits/attention_unittest/runner_modes/metadata_invariants.py 测试工具 added 7.23
python/sglang/test/kits/attention_unittest/runner_modes/speculative_cuda_graph_runner.py 测试 runner modified 6.34
test/registered/attention/unittests/dense/test_fa3.py FA3 测试 modified 5.27
test/registered/attention/unittests/dense/test_fa4.py FA4 测试 modified 5.27
test/registered/attention/unittests/dense/test_triton.py Triton 测试 modified 5.16
test/registered/attention/unittests/mla/test_triton.py MLA 测试 modified 5.11
python/sglang/test/kits/attention_unittest/runner_modes/speculative_draft_extend_runner.py 扩展 runner modified 4.78
python/sglang/test/kits/attention_unittest/runner_modes/cuda_graph_decode_runner.py 解码 runner modified 4.38

关键符号

assert_cg_metadata_well_formed _apply_prod_fill_padding _slice_bs_plus_one _slice_bs

关键源码片段

python/sglang/test/kits/attention_unittest/runner_modes/metadata_invariants.py test-coverage

新增文件,定义后端无关的 CG replay 元数据不变式检查函数 `assert_cg_metadata_well_formed`,是本次测试强化的核心工具。

"""Backend-agnostic assertions on `forward_metadata` after CG replay init.Catches corruption that the output-equality assertion misses — e.g., negative
per-request lengths or non-monotonic indptr that happen to leave real-row
output correct while corrupting padded-row scratch state.Usage from a CG runner kit:    from .metadata_invariants import assert_cg_metadata_well_formed
    _init_cuda_graph_replay_metadata(backend, capture_batch_size, replay_batch)
    assert_cg_metadata_well_formed(backend, bs=capture_batch_size)
"""from __future__ import annotations
from typing import Any
import torch# CSR indptr 字段名列表,期望非递减且首元素为 0
_INDPTR_FIELDS = (
    "kv_indptr", "qo_indptr", "mask_indptr", "window_kv_indptr",
    "cu_seqlens_q", "cu_seqlens_k", "encoder_cu_seqlens_k",
)# 每个请求的长度字段列表,期望非负
_PER_REQ_LEN_FIELDS = (
    "cache_seqlens_int32", "encoder_lens_int32", "local_seqused_k",
    "max_seq_len_k", # 可能是标量,防御性检查
)def _slice_bs_plus_one(t: torch.Tensor, bs: int) -> torch.Tensor:
    # indptr 长度为 bs+1,有些后端预分配了 max_bs+1,取前 bs+1 个元素
    return t[: bs + 1] if t.numel() >= bs + 1 else tdef _slice_bs(t: torch.Tensor, bs: int) -> torch.Tensor:
    return t[:bs] if t.numel() >= bs else tdef assert_cg_metadata_well_formed(backend: Any, bs: int) -> None:
    """检查 backend.forward_metadata 是否明显损坏。
    Best-effort: 仅当字段存在且为 tensor 时才检查。
    """
    meta = getattr(backend, "forward_metadata", None)
    if meta is None:
        return
​
    errors: list[str] = []
​
    # 检查所有 indptr 字段:非递减且首元素为 0
    for field in _INDPTR_FIELDS:
        t = getattr(meta, field, None)
        if not isinstance(t, torch.Tensor):
            continue
        sliced = _slice_bs_plus_one(t, bs)
        if sliced.numel() < 2:
            continue
        diff = sliced[1:].to(torch.int64) - sliced[:-1].to(torch.int64)
        # 允许零长度请求 (diff==0),拒绝负 diff
        if (diff < 0).any().item():
            min_diff = diff.min().item()
            errors.append(f"{field} 在 bs={bs} 处非单调递减 (相邻最小差={min_diff}); 切片={sliced[:min(bs+1,16)].tolist()}")
        if sliced[0].item() != 0:
            errors.append(f"{field}[0] != 0 (为 {sliced[0].item()}) 在 bs={bs}")
​
    # 检查所有 per-request 长度字段:非负
    for field in _PER_REQ_LEN_FIELDS:
        t = getattr(meta, field, None)
        if not isinstance(t, torch.Tensor):
            continue
        sliced = _slice_bs(t, bs)
        if sliced.numel() == 0:
            continue
        if (sliced < 0).any().item():
            min_v = sliced.min().item()
            errors.append(f"{field} 在 bs={bs} 处出现负值 (最小值={min_v}); 切片={sliced[:min(bs,16)].tolist()}")
​
    if errors:
        raise AssertionError(
            "CG forward_metadata 不变式在 replay init 后被违反:\n  - "
            + "\n  - ".join(errors)
        )
python/sglang/test/kits/attention_unittest/runner_modes/speculative_cuda_graph_runner.py test-coverage

修改核心 runner,新增生产风格填充函数和 PadStyle 类型,是测试生产模拟的关键入口。

def _apply_prod_fill_padding(
    batch,
    *,
    real_bs: int,
    capture_bs: int,
    seq_len_fill_value: int,
    num_tokens_per_bs: int,
) -> None:
    """按照生产 CG runner 的方式覆写 batch 中填充行的元数据。    生产环境将填充行设置为:seq_lens=fill, extend_seq_lens=N,
    req_pool_indices=0, out_cache_loc=0, positions=0,
    input_ids=0。这样 seq_lens - extend_seq_lens 会变为负值,
    后端必须对此进行防御(如 clamp)。
    """
    if real_bs >= capture_bs:
        return
    pad_lo, pad_hi = real_bs, capture_bs
​
    # 覆写 per-request 长度张量
    batch.seq_lens[pad_lo:pad_hi] = seq_len_fill_value
    if batch.seq_lens_cpu is not None:
        batch.seq_lens_cpu[pad_lo:pad_hi] = seq_len_fill_value
        batch.seq_lens_sum = int(batch.seq_lens_cpu.sum())
​
    if getattr(batch, "extend_seq_lens", None) is not None:
        batch.extend_seq_lens[pad_lo:pad_hi] = num_tokens_per_bs
    if getattr(batch, "extend_seq_lens_cpu", None) is not None:
        ext = list(batch.extend_seq_lens_cpu)
        for i in range(pad_lo, min(pad_hi, len(ext))):
            ext[i] = num_tokens_per_bs
        batch.extend_seq_lens_cpu = ext
​
    # 覆写 per-request slot 张量
    batch.req_pool_indices[pad_lo:pad_hi] = 0
​
    # 覆写 per-token 张量:填充行占据 [real_bs*N, capture_bs*N) 切片
    tok_lo = pad_lo * num_tokens_per_bs
    tok_hi = pad_hi * num_tokens_per_bs
    for field in ("out_cache_loc", "positions", "input_ids"):
        t = getattr(batch, field, None)
        if t is not None and t.numel() >= tok_hi:
            t[tok_lo:tok_hi] = 0
​
    # 同步 spec_info 中的 extend_seq_lens_tensor
    spec_info = getattr(batch, "spec_info", None)
    if spec_info is not None:
        eslt = getattr(spec_info, "extend_seq_lens_tensor", None)
        if isinstance(eslt, torch.Tensor) and eslt.numel() >= pad_hi:
            eslt[pad_lo:pad_hi] = num_tokens_per_bs
        eslc = getattr(spec_info, "extend_seq_lens_cpu", None)
        if isinstance(eslc, (list, tuple)) and len(eslc) >= pad_hi:
            ext2 = list(eslc)
            for i in range(pad_lo, pad_hi):
                ext2[i] = num_tokens_per_bs
            spec_info.extend_seq_lens_cpu = ext2

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

低风险。本PR仅修改测试代码,不影响生产路径。但需注意:

  • 新增的assert_cg_metadata_well_formed可能在debug模式下引入额外开销,但每个CG replay只调用一次,影响可忽略。
  • pad_style="prod_fill"仅在DRAFT_EXTEND_V2路径使用,其他CG路径(TARGET_VERIFY、DRAFT_EXTEND_V1)未覆盖,未来需扩展。
  • FA3/FA4的DRAFT_EXTEND_V2测试仍被skipTest跳过(已知问题),测试强化不影响现有skip逻辑。

影响范围:仅限Attention Unittest套件(4-gpu-b200 CI)。新增约10个subTest,总subTest数从525增至535。对CUDA Graph replay路径的测试置信度显著提升,能捕获负kv_lens等之前遗漏的bug。团队可参考metadata_invariants.py的设计模式为其他模块添加类似的不变式检查。

仅覆盖 DRAFT_EXTEND_V2 新增断言可能假阳性但可忽略

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论