执行摘要
- 一句话:强化CG-replay测试:生产填充、元数据不变式、多比例扫描
- 推荐动作:值得精读。本PR展示了如何通过分析测试/生产环境差异来设计有针对性的测试覆盖。
assert_cg_metadata_well_formed的设计原则(best-effort、静默跳过、单语句检查)和pad_style抽象值得在其他测试套件中复用。
功能与动机
Codex review of #26651发现了一个真实bug(CG replay时kv_lens为负),而现有测试因填充模式不匹配未能捕获。根本原因是测试用的填充模式(padded行的seq_lens设为capture_prefix_len + N)与生产环境(padded行的seq_lens设为seq_len_fill_value、extend_seq_lens设为num_tokens_per_bs,导致减出来为负)不同。本PR旨在缩小测试与生产环境之间的差距,并增加两层防御。
实现拆解
- 新增
metadata_invariants.py(python/sglang/test/kits/attention_unittest/runner_modes/metadata_invariants.py):定义 assert_cg_metadata_well_formed 函数,在CG replay元数据初始化后执行,检查forward_metadata上的indptr数组是否单调非递减且首元素为0、per-request长度是否非负。支持FA和Triton后端,缺失字段时静默跳过。
- 修改
speculative_cuda_graph_runner.py:为SpeculativeCudaGraphAdapter类新增pad_style字段(Literal["small_real", "prod_fill"])和pad_num_tokens_per_bs字段;实现_apply_prod_fill_padding函数,按生产版本的方式覆写padded行的seq_lens、extend_seq_lens、req_pool_indices、out_cache_loc、positions、input_ids等,使padded行的seq_lens - extend_seq_lens为负,考验backends的鲁棒性。
- 修改各测试文件(
test_fa3.py、test_fa4.py、test_triton.py、test_mla_triton.py)中的DRAFT_EXTEND_V2测试方法:从单个subTest扩展为pad_style ∈ {small_real, prod_fill} × capture_bs ∈ {1×, 2×, 4×}的6个subTest,覆盖0%、50%、75%填充比例。
- 配套调整:在
cuda_graph_decode_runner.py中导入并调用assert_cg_metadata_well_formed;在speculative_draft_extend_runner.py中添加pad_style参数透传。
关键文件:
python/sglang/test/kits/attention_unittest/runner_modes/metadata_invariants.py(模块 测试工具;类别 test;类型 test-coverage;符号 _slice_bs_plus_one, _slice_bs, assert_cg_metadata_well_formed): 新增文件,定义后端无关的CG replay元数据不变式检查函数 assert_cg_metadata_well_formed,是本次测试强化的核心工具。
python/sglang/test/kits/attention_unittest/runner_modes/speculative_cuda_graph_runner.py(模块 测试runner;类别 test;类型 test-coverage;符号 _apply_prod_fill_padding, PadStyle): 修改核心runner,新增生产风格填充函数和PadStyle类型,是测试生产模拟的关键入口。
test/registered/attention/unittests/dense/test_fa3.py(模块 FA3测试;类别 test;类型 test-coverage): DRAFT_EXTEND_V2测试用例被扩展为pad_style×capture_bs的6个subTest,是全套测试中覆盖比率扫描的主要场景之一。
test/registered/attention/unittests/dense/test_fa4.py(模块 FA4测试;类别 test;类型 test-coverage): 同上,FA4后端同等扩展。
test/registered/attention/unittests/dense/test_triton.py(模块 Triton测试;类别 test;类型 test-coverage): Triton dense后端增加pad_style和capture_bs参数。
test/registered/attention/unittests/mla/test_triton.py(模块 MLA测试;类别 test;类型 test-coverage): MLA Triton后端同样增加pad_style和capture_bs参数。
python/sglang/test/kits/attention_unittest/runner_modes/speculative_draft_extend_runner.py(模块 扩展runner;类别 test;类型 test-coverage): 透传pad_style参数到下游函数,使prod_fill模式生效。
python/sglang/test/kits/attention_unittest/runner_modes/cuda_graph_decode_runner.py(模块 解码runner;类别 test;类型 test-coverage): 在replay元数据初始化后调用assert_cg_metadata_well_formed,确保所有CG路径都能受益。
关键符号:assert_cg_metadata_well_formed, _apply_prod_fill_padding, _slice_bs_plus_one, _slice_bs
关键源码片段
python/sglang/test/kits/attention_unittest/runner_modes/metadata_invariants.py
新增文件,定义后端无关的CG replay元数据不变式检查函数 assert_cg_metadata_well_formed,是本次测试强化的核心工具。
"""Backend-agnostic assertions on `forward_metadata` after CG replay init.
Catches corruption that the output-equality assertion misses — e.g., negative
per-request lengths or non-monotonic indptr that happen to leave real-row
output correct while corrupting padded-row scratch state.
Usage from a CG runner kit:
from .metadata_invariants import assert_cg_metadata_well_formed
_init_cuda_graph_replay_metadata(backend, capture_batch_size, replay_batch)
assert_cg_metadata_well_formed(backend, bs=capture_batch_size)
"""
from __future__ import annotations
from typing import Any
import torch
# CSR indptr 字段名列表,期望非递减且首元素为 0
_INDPTR_FIELDS = (
"kv_indptr", "qo_indptr", "mask_indptr", "window_kv_indptr",
"cu_seqlens_q", "cu_seqlens_k", "encoder_cu_seqlens_k",
)
# 每个请求的长度字段列表,期望非负
_PER_REQ_LEN_FIELDS = (
"cache_seqlens_int32", "encoder_lens_int32", "local_seqused_k",
"max_seq_len_k", # 可能是标量,防御性检查
)
def _slice_bs_plus_one(t: torch.Tensor, bs: int) -> torch.Tensor:
# indptr 长度为 bs+1,有些后端预分配了 max_bs+1,取前 bs+1 个元素
return t[: bs + 1] if t.numel() >= bs + 1 else t
def _slice_bs(t: torch.Tensor, bs: int) -> torch.Tensor:
return t[:bs] if t.numel() >= bs else t
def assert_cg_metadata_well_formed(backend: Any, bs: int) -> None:
"""检查 backend.forward_metadata 是否明显损坏。
Best-effort: 仅当字段存在且为 tensor 时才检查。
"""
meta = getattr(backend, "forward_metadata", None)
if meta is None:
return
errors: list[str] = []
# 检查所有 indptr 字段:非递减且首元素为 0
for field in _INDPTR_FIELDS:
t = getattr(meta, field, None)
if not isinstance(t, torch.Tensor):
continue
sliced = _slice_bs_plus_one(t, bs)
if sliced.numel() < 2:
continue
diff = sliced[1:].to(torch.int64) - sliced[:-1].to(torch.int64)
# 允许零长度请求 (diff==0),拒绝负 diff
if (diff < 0).any().item():
min_diff = diff.min().item()
errors.append(f"{field} 在 bs={bs} 处非单调递减 (相邻最小差={min_diff}); 切片={sliced[:min(bs+1,16)].tolist()}")
if sliced[0].item() != 0:
errors.append(f"{field}[0] != 0 (为 {sliced[0].item()}) 在 bs={bs}")
# 检查所有 per-request 长度字段:非负
for field in _PER_REQ_LEN_FIELDS:
t = getattr(meta, field, None)
if not isinstance(t, torch.Tensor):
continue
sliced = _slice_bs(t, bs)
if sliced.numel() == 0:
continue
if (sliced < 0).any().item():
min_v = sliced.min().item()
errors.append(f"{field} 在 bs={bs} 处出现负值 (最小值={min_v}); 切片={sliced[:min(bs,16)].tolist()}")
if errors:
raise AssertionError(
"CG forward_metadata 不变式在 replay init 后被违反:
- "
+ "
- ".join(errors)
)
python/sglang/test/kits/attention_unittest/runner_modes/speculative_cuda_graph_runner.py
修改核心runner,新增生产风格填充函数和PadStyle类型,是测试生产模拟的关键入口。
def _apply_prod_fill_padding(
batch,
*,
real_bs: int,
capture_bs: int,
seq_len_fill_value: int,
num_tokens_per_bs: int,
) -> None:
"""按照生产 CG runner 的方式覆写 batch 中填充行的元数据。
生产环境将填充行设置为:seq_lens=fill, extend_seq_lens=N,
req_pool_indices=0, out_cache_loc=0, positions=0,
input_ids=0。这样 seq_lens - extend_seq_lens 会变为负值,
后端必须对此进行防御(如 clamp)。
"""
if real_bs >= capture_bs:
return
pad_lo, pad_hi = real_bs, capture_bs
# 覆写 per-request 长度张量
batch.seq_lens[pad_lo:pad_hi] = seq_len_fill_value
if batch.seq_lens_cpu is not None:
batch.seq_lens_cpu[pad_lo:pad_hi] = seq_len_fill_value
batch.seq_lens_sum = int(batch.seq_lens_cpu.sum())
if getattr(batch, "extend_seq_lens", None) is not None:
batch.extend_seq_lens[pad_lo:pad_hi] = num_tokens_per_bs
if getattr(batch, "extend_seq_lens_cpu", None) is not None:
ext = list(batch.extend_seq_lens_cpu)
for i in range(pad_lo, min(pad_hi, len(ext))):
ext[i] = num_tokens_per_bs
batch.extend_seq_lens_cpu = ext
# 覆写 per-request slot 张量
batch.req_pool_indices[pad_lo:pad_hi] = 0
# 覆写 per-token 张量:填充行占据 [real_bs*N, capture_bs*N) 切片
tok_lo = pad_lo * num_tokens_per_bs
tok_hi = pad_hi * num_tokens_per_bs
for field in ("out_cache_loc", "positions", "input_ids"):
t = getattr(batch, field, None)
if t is not None and t.numel() >= tok_hi:
t[tok_lo:tok_hi] = 0
# 同步 spec_info 中的 extend_seq_lens_tensor
spec_info = getattr(batch, "spec_info", None)
if spec_info is not None:
eslt = getattr(spec_info, "extend_seq_lens_tensor", None)
if isinstance(eslt, torch.Tensor) and eslt.numel() >= pad_hi:
eslt[pad_lo:pad_hi] = num_tokens_per_bs
eslc = getattr(spec_info, "extend_seq_lens_cpu", None)
if isinstance(eslc, (list, tuple)) and len(eslc) >= pad_hi:
ext2 = list(eslc)
for i in range(pad_lo, pad_hi):
ext2[i] = num_tokens_per_bs
spec_info.extend_seq_lens_cpu = ext2
评论区精华
PR无review评论。PR作者在body中详细解释了设计决策:选择生产风格填充而非简单模拟,是因为真实生产场景中padded行元数据不一致(seq_lens和extend_seq_lens组合会导致负值),必须被后端正确处理。metadata_invariants的检查是best-effort的,不会导致错误阻断,但对已暴露的字段执行强检查。
风险与影响
- 风险:低风险。本PR仅修改测试代码,不影响生产路径。但需注意:
- 新增的
assert_cg_metadata_well_formed可能在debug模式下引入额外开销,但每个CG replay只调用一次,影响可忽略。
pad_style="prod_fill"仅在DRAFT_EXTEND_V2路径使用,其他CG路径(TARGET_VERIFY、DRAFT_EXTEND_V1)未覆盖,未来需扩展。
- FA3/FA4的DRAFT_EXTEND_V2测试仍被skipTest跳过(已知问题),测试强化不影响现有skip逻辑。
- 影响:影响范围:仅限Attention Unittest套件(4-gpu-b200 CI)。新增约10个subTest,总subTest数从525增至535。对CUDA Graph replay路径的测试置信度显著提升,能捕获负kv_lens等之前遗漏的bug。团队可参考
metadata_invariants.py的设计模式为其他模块添加类似的不变式检查。
- 风险标记:仅覆盖DRAFT_EXTEND_V2, 新增断言可能假阳性但可忽略
关联脉络
- PR #26655 Fix TRTLLM MHA draft decode cache seqlens replay: 同为CG replay相关的bugfix,本PR的测试增强可防止同类问题
- PR #26628 Revert "Fix FA DRAFT_EXTEND_V2 cache extent": 涉及DRAFT_EXTEND_V2 cache extent修复的回滚,本PR增强了该路径的测试覆盖
参与讨论