执行摘要
- 一句话:新增确定性 token oracle 与 write-input 断言
- 推荐动作:该 PR 设计清晰,分层合理,值得精读。特别是 Token Oracle 的确定性哈希选择、perturb 与断言分离的测试策略、以及通过环境变量精细控制功能启用的思路,可作为类似可观测性功能的参考模板。
功能与动机
PR body 提到需要让 kv-canary 的 write-input 断言能够验证模型消费了正确的 input_ids/positions,而前提是有一个确定性的 token oracle 来预先知道每个请求的采样结果。此外,通过扰动机制可以证明断言检查链路在真实运行时是活跃的。
实现拆解
- 确定性 Token Oracle 核心:在
python/sglang/srt/kv_canary/token_oracle/oracle.py 中定义 TokenOracle 协议和 HashOracle 实现,使用 SplitMix64 哈希将 (generalized_req_id, position) 确定性地映射到 token_id。
- TokenOracleManager 与采样器:在
oracle_manager.py 实现 TokenOracleManager,负责填充预期输入(fill_expected_inputs)和生成下一 token(sample_next_tokens);在 sampler.py 中注册 "token_oracle" 采样后端,替换标准采样逻辑,并通过 install_oracle_sampler 注入依赖。
- Perturb 扰动机制:在
next_token_swap.py 中实现 maybe_perturb_swap_next_tokens,以指定概率随机交换批内两个请求的采样 token,仅影响采样输出,不碰 KV 路径,用于验证 write-input 断言能检测到输入不一致。
- Write-Input 断言集成:在
single_forward_manager/manager.py 中通过 _should_enable_write_input_assert_for_launch 判断是否启用断言(默认关闭,且对 EAGLE draft 解码步骤跳过),利用 TokenOracleManager.fill_expected_inputs 填充预期值,在写入内核执行后比较实际 input_ids/positions。
- 环境变量与配置:通过
server_args.py 注册 "token_oracle" 采样后端,在 environ.py 新增 SGLANG_KV_CANARY_ENABLE_TOKEN_ORACLE、SGLANG_KV_CANARY_ENABLE_WRITE_INPUT_ASSERT、SGLANG_KV_CANARY_PERTURB_NEXT_TOKEN_SWAP_PROB、SGLANG_KV_CANARY_PERTURB_WARMUP_STEPS 等变量;CanaryConfig 中新增 enable_write_input_assert 配置项。
- 测试覆盖:包括 oracle 单元测试、mock 连线测试、token oracle manager 行为测试、perturb e2e 测试、EAGLE positions 回归测试等。
关键文件:
python/sglang/srt/kv_canary/token_oracle/oracle.py(模块 KVCache;类别 source;类型 core-logic;符号 TokenOracle, expected_tokens, HashOracle, _splitmix64_tensor): 定义 TokenOracle 协议和 HashOracle 实现,是确定性 token 映射的核心算法。
python/sglang/srt/kv_canary/token_oracle/oracle_manager.py(模块 KVCache;类别 source;类型 dependency-wiring;符号 TokenOracleManager, init, fill_expected_inputs, sample_next_tokens): 管理 TokenOracle 实例,提供 fill_expected_inputs 和 sample_next_tokens 两个核心方法,并处理不同 forward mode 下 generalized_req_id 的扩展。
python/sglang/srt/kv_canary/token_oracle/sampler.py(模块 KVCache;类别 source;类型 dependency-wiring;符号 install_oracle_sampler, _OracleSampler, init, forward): 将 TokenOracleManager 注册为采样后端,替换标准采样逻辑,并集成 next-token-swap 扰动。
test/registered/mock_model/test_self_unit_canary_mock_wiring.py(模块 单元测试;类别 test;类型 test-coverage;符号 _StubForwardBatch, _scalar_expected_token, TestFillExpectedInputs, test_sample_next_tokens_uses_next_position): 测试 TokenOracleManager 核心逻辑(fill_expected_inputs, sample_next_tokens),覆盖多种 forward mode。
test/registered/mock_model/test_self_unit_oracle.py(模块 单元测试;类别 test;类型 test-coverage;符号 _signed_to_unsigned_i64, _call, TestHashOracle, test_hash_oracle_is_deterministic_for_same_inputs): 验证 HashOracle 的确定性和范围正确性,并对比 tensor 实现与标量 SplitMix64 参考实现。
python/sglang/srt/kv_canary/perturb/next_token_swap.py(模块 KVCache;类别 source;类型 core-logic;符号 NextTokenSwapConfig, from_env, _get_config, maybe_perturb_swap_next_tokens): 实现 next-token-swap 扰动,用于测试断言链路活性,是测试基础设施的一部分。
python/sglang/srt/kv_canary/single_forward_manager/manager.py(模块 KVCache;类别 source;类型 dependency-wiring;符号 _should_enable_write_input_assert_for_launch): 整合 write-input 断言到前向传播管理器中,对接 TokenOracleManager 和写内核。
python/sglang/srt/kv_canary/token_oracle/install.py(模块 KVCache;类别 source;类型 core-logic;符号 install_token_oracle_from_env): 提供环境变量驱动的安装入口,简化外部调用。
test/registered/kv_canary/test_self_e2e_pr_25015.py(模块 端到端测试;类别 test;类型 test-coverage;符号 _EaglePositionsBase, setUpClass, test_pr_25015_eagle_positions, TestEaglePositionsMisalignRegression): 回归测试 EAGLE positions 对齐问题,确保 token oracle 与 positions 配合正确。
test/registered/mock_model/test_self_unit_install.py(模块 单元测试;类别 test;类型 test-coverage;符号 _make_server_args, TestInstallTokenOracleFromEnv, test_install_token_oracle_from_env_disabled_returns_none, test_install_token_oracle_from_env_enabled_registers_oracle_backend): 测试 install_token_oracle_from_env 的开关逻辑和正确性。
test/registered/mock_model/test_self_unit_oracle_torch_vs_ref.py(模块 单元测试;类别 test;类型 test-coverage;符号 TestHashOracleTorchVsRef, test_hash_oracle_matches_scalar_splitmix64_ref, test_hash_oracle_batched_matches_scalar_splitmix64_ref): 对比 PyTorch 实现的 SplitMix64 与标量参考实现,确保数值一致。
test/registered/kv_canary/test_self_unit_token_oracle.py(模块 单元测试;类别 test;类型 test-coverage;符号 TestTokenOracleManager, setUp, test_fill_expected_inputs_expands_draft_extend_generalized_req_ids_per_token): 测试 TokenOracleManager 的扩展逻辑(draft_extend 模式)。
关键符号:TokenOracleManager.fill_expected_inputs, TokenOracleManager.sample_next_tokens, HashOracle.expected_tokens, _OracleSampler.forward, maybe_perturb_swap_next_tokens, install_oracle_sampler, install_token_oracle_from_env, _should_enable_write_input_assert_for_launch
关键源码片段
python/sglang/srt/kv_canary/token_oracle/oracle.py
定义 TokenOracle 协议和 HashOracle 实现,是确定性 token 映射的核心算法。
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
import torch
class TokenOracle(Protocol):
"""Deterministic (generalized_req_id, position) -> token_id mapping."""
def expected_tokens(
self, *, generalized_req_ids: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor: ...
@dataclass(frozen=True, slots=True, kw_only=True)
class HashOracle:
"""token_id = splitmix64(generalized_req_id XOR position) % vocab_size."""
vocab_size: int
def expected_tokens(
self, *, generalized_req_ids: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor:
# 输入广义请求 ID 和位置,XOR 后经过 SplitMix64 哈希,再取模词表大小
x = generalized_req_ids.to(torch.int64) ^ positions.to(torch.int64)
x = _splitmix64_tensor(x)
return _uint64_mod(x, self.vocab_size).to(torch.int32)
_C1: int = -4658895280553007687 # 0xBF58476D1CE4E5B9 作为有符号 int64
_C2: int = -7723592293110705685 # 0x94D049BB133111EB 作为有符号 int64
def _splitmix64_tensor(x: torch.Tensor) -> torch.Tensor:
# 对张量逐元素执行 SplitMix64 混合步骤
x = (x ^ _logical_shr(x, 30)) * _C1
x = (x ^ _logical_shr(x, 27)) * _C2
x = x ^ _logical_shr(x, 31)
return x
def _logical_shr(x: torch.Tensor, n: int) -> torch.Tensor:
# 逻辑右移,用于 PyTorch 有符号整数的无符号移位
return (x >> n) & ((1 << (64 - n)) - 1)
def _uint64_mod(x: torch.Tensor, mod: int) -> torch.Tensor:
# 正确处理 PyTorch 有符号 int64 取模,使其行为如同无符号
offset = (1 << 64) % mod
base = x % mod
correction = (x < 0).to(x.dtype) * offset
return (base + correction) % mod
python/sglang/srt/kv_canary/token_oracle/oracle_manager.py
管理 TokenOracle 实例,提供 fill_expected_inputs 和 sample_next_tokens 两个核心方法,并处理不同 forward mode 下 generalized_req_id 的扩展。
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from sglang.srt.kv_canary.expected_inputs import ExpectedInputs
from sglang.srt.kv_canary.token_oracle.oracle import TokenOracle
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class TokenOracleManager:
def __init__(self, *, oracle: TokenOracle) -> None:
self.oracle = oracle
def fill_expected_inputs(
self, *, forward_batch: "ForwardBatch", expected_inputs_out: ExpectedInputs
) -> None:
# 从 forward batch 获取当前输入和位置
positions = forward_batch.positions
input_ids = forward_batch.input_ids
num_tokens = int(input_ids.shape[0])
if num_tokens == 0:
return
# 构建每个 token 对应的广义请求 ID(考虑 bootstrap room 和不同 forward mode 的展开)
generalized_req_ids = _build_generalized_req_id_per_token(
forward_batch=forward_batch, num_tokens=num_tokens,
generalized_req_ids_per_row=select_generalized_req_ids(
vanilla_req_ids=forward_batch.rids_int,
bootstrap_room_ids_int=forward_batch.bootstrap_room_ids_int,
),
)
# extend 模式下实际 input_ids 就是预期 token;decode 模式通过 oracle 确定
if forward_batch.forward_mode.is_extend():
expected_tokens = input_ids
else:
expected_tokens = self.oracle.expected_tokens(
generalized_req_ids=generalized_req_ids,
positions=positions.to(torch.int64),
)
# 将预期值和实际位置写入输出缓冲区
expected_inputs_out.tokens[:num_tokens].copy_(expected_tokens.to(torch.int64))
expected_inputs_out.positions[:num_tokens].copy_(positions.to(torch.int64))
def sample_next_tokens(
self, *, generalized_req_ids: torch.Tensor, logits_positions: torch.Tensor
) -> torch.Tensor:
# 采样下一 token 时位置加 1,因为预测的是下一个位置
return self.oracle.expected_tokens(
generalized_req_ids=generalized_req_ids,
positions=logits_positions.to(torch.int64) + 1,
)
# 其余辅助函数(_build_generalized_req_id_per_token、_expand_uniform、select_generalized_req_ids)省略
python/sglang/srt/kv_canary/token_oracle/sampler.py
将 TokenOracleManager 注册为采样后端,替换标准采样逻辑,并集成 next-token-swap 扰动。
from __future__ import annotations
from typing import TYPE_CHECKING, List
import torch
from sglang.srt.kv_canary.perturb.next_token_swap import maybe_perturb_swap_next_tokens
from sglang.srt.kv_canary.token_oracle.oracle import TokenOracle
from sglang.srt.kv_canary.token_oracle.oracle_manager import (
TokenOracleManager, select_generalized_req_ids,
)
from sglang.srt.layers.sampler import Sampler, register_sampler_backend
if TYPE_CHECKING:
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
def install_oracle_sampler(*, oracle: TokenOracle) -> TokenOracleManager:
"""注册 token_oracle 采样后端并返回 manager 供外部使用。"""
manager = TokenOracleManager(oracle=oracle)
register_sampler_backend(
"token_oracle",
lambda: _OracleSampler(token_oracle_manager=manager),
)
return manager
class _OracleSampler(Sampler):
def __init__(self, *, token_oracle_manager: TokenOracleManager) -> None:
super().__init__()
self._token_oracle_manager = token_oracle_manager
def forward(
self,
logits_output: "LogitsProcessorOutput",
sampling_info: "SamplingBatchInfo",
return_logprob: bool,
top_logprobs_nums: List[int],
token_ids_logprobs: List[List[int]],
positions: torch.Tensor,
) -> torch.Tensor:
# 使用 sampling_info 的 rids_int 作为 generalized_req_id 来源
vanilla_req_ids = sampling_info.rids_int
if vanilla_req_ids is None:
raise RuntimeError(
"_OracleSampler.forward: generalized_req_id source tensor is None; "
"token oracle requires a per-forward generalized_req_id source tensor "
"(set in ForwardBatch.init_new when SGLANG_KV_CANARY_ENABLE_TOKEN_ORACLE=1)"
)
# 通过 oracle manager 获取下一 token(基于当前位置 +1)
batch_next_token_ids = self._token_oracle_manager.sample_next_tokens(
generalized_req_ids=select_generalized_req_ids(
vanilla_req_ids=vanilla_req_ids,
bootstrap_room_ids_int=sampling_info.bootstrap_room_ids_int,
),
logits_positions=positions,
)
# 可选扰动:交换两个请求的 token 以验证断言链路
batch_next_token_ids = maybe_perturb_swap_next_tokens(batch_next_token_ids)
return batch_next_token_ids
评论区精华
风险与影响
- 风险:所有新功能均默认关闭,对现有推理路径无影响。启用时风险包括:1)确定性哈希算法(SplitMix64)的正确性依赖,如果存在未发现的 bug 可能导致错误 passed;2)采样后端替换可能与其他自定义采样器冲突;3)write-input 断言增加了每次前向传播的额外计算开销(但仅比较张量,开销很低);4)perturb 扰动仅用于测试,若误在生产环境启用可能引入随机错误。
- 影响:对用户无感知,所有功能默认关闭。对系统增加了 kv-canary 的可观测性能力,使开发者能够端到端验证 KV 缓存写入的正确性。对团队提供了健壮的测试手段(perturb)来证明断言机制有效。影响范围限于 kv-canary 子系统,不涉及核心推理路径。
- 风险标记:默认关闭无影响, 确定性哈希算法正确性依赖, 采样后端替换可能冲突, Perturb 仅用于测试
关联脉络
- PR #26816 [kv-canary] Add the KV-canary perturb framework for fault-injection self-tests: 本 PR 的 next-token-swap 扰动构建在 #26816 引入的 perturb 框架之上。
参与讨论