执行摘要
- 一句话:添加KV-canary核心:数据层、KV池修补器和前向运行器
- 推荐动作:此 PR 是 KV-canary 系列的基础,建议精读理解设计模式:状态管理、池修补、前向钩子集成。关注 assert 替换为 Exception 的讨论,这是生产代码的重要稳健性考量。
功能与动机
PR 描述指出这是一个自包含的金丝雀核心,后续将分 PR 添加更多池支持和 E2E 测试。动机是为 sglang 提供一种可插拔的运行时验证机制,用于检测 KV-cache 中的数据损坏。
实现拆解
- 数据层与状态管理:新增
config、state、buffer_group、capacities 等文件,定义金丝雀的配置和运行时状态,包括 CanaryDeviceState 和 ViolationLog。
- KV-pool 修补器:在
pool_patcher/buf_info_splice.py 中修补 KV-pool 的缓冲区信息方法,插入金丝雀缓冲区视图。
- 端点与运行器:实现
CanaryEndpoint(代表一个验证-写对)和 SingleForwardManager(管理单次 forward 的验证和写入计划),以及顶层的 CanaryManager 协调多个 forward 管理器。
- 验证与违规报告:
ViolationReporter 和 ViolationManager 负责从 GPU 读取违规环形缓冲区并在违规时记录或抛出异常。ViolationReporter 支持 LOG 和 RAISE 两种模式。
- 集成与测试:修改
forward_batch_info.py 以导入 envs(可能用于环境变量),新增 18+ 个测试文件覆盖单元和集成测试,验证金丝雀管线的正确性。
关键文件:
python/sglang/srt/kv_canary/single_forward_manager/manager.py(模块 前向管理器;类别 source;类型 core-logic;符号 SingleForwardManager, _SingleForwardPhase, _PreOpsMaybeInsideGraphOutput, pre_ops_outside_graph): 定义了 SingleForwardManager 管理单次 forward 的验证/写入计划,是核心运行器之一。
python/sglang/srt/kv_canary/runner/canary_manager.py(模块 金丝雀运行器;类别 source;类型 core-logic;符号 CanaryManager, with_active_single_forward_manager, pre_ops_maybe_inside_graph, post_ops_maybe_inside_graph): 顶层协调器 CanaryManager,负责生命周期、端点构建和 forward 钩子的调度。
python/sglang/srt/kv_canary/endpoint.py(模块 端点;类别 source;类型 core-logic;符号 CanaryEndpoint, launch_per_forward, _make_verify_or_write_context, build_endpoints_from_group): 定义 CanaryEndpoint,封装单个验证/写对的启动逻辑。
python/sglang/srt/kv_canary/runner/violation_reporter.py(模块 违规报告器;类别 source;类型 core-logic;符号 ViolationReporter, log_or_raise_violation, _format_violation): 负责从 GPU 环形缓冲区读取违规并记录/抛出,是验证结果输出的关键。
python/sglang/srt/kv_canary/state.py(模块 状态管理;类别 source;类型 dependency-wiring;符号 CanaryDeviceState, ViolationLog, allocate): 定义 CanaryDeviceState 和 ViolationLog,管理设备端计数器和违规环形缓冲区的分配。
python/sglang/srt/kv_canary/capacities.py(模块 容量配置;类别 source;类型 dependency-wiring;符号 CanaryLaunchCapacities, from_args): 定义 CanaryLaunchCapacities,配置每次 forward 的 verify/write 容量。
关键符号:SingleForwardManager.init, SingleForwardManager.pre_ops_outside_graph, SingleForwardManager.pre_ops_maybe_inside_graph, CanaryManager.init, CanaryManager.with_active_single_forward_manager, CanaryManager.pre_ops_maybe_inside_graph, CanaryEndpoint.launch_per_forward, CanaryEndpoint._make_verify_or_write_context, build_endpoints_from_group, ViolationReporter.log_or_raise_violation, ViolationReporter._format_violation, CanaryDeviceState.allocate, CanaryLaunchCapacities.from_args, PlanInput.allocate, PlanInput.fill_from_forward_batch, launch_endpoints_per_forward, invoke_plan
关键源码片段
python/sglang/srt/kv_canary/single_forward_manager/manager.py
定义了 SingleForwardManager 管理单次 forward 的验证/写入计划,是核心运行器之一。
# 文件 : python/sglang/srt/kv_canary/single_forward_manager/manager.py
class _SingleForwardPhase(IntEnum):
IDLE = 0
AFTER_PRE_OUT = 1
AFTER_PRE_MAYBE_IN = 2
AFTER_POST_MAYBE_IN = 3
class SingleForwardManager:
def __init__(self, ..., d2h_stream: torch.cuda.Stream):
self._phase_checker = SimplePhaseChecker(initial_phase=_SingleForwardPhase.IDLE, device=device)
self._output_buffer = PostOpsInsideGraphOutputBuffer.allocate(...)
def pre_ops_outside_graph(self, *, maybe_inaccurate_forward_batch: ForwardBatch) -> None:
self._phase_checker.update(expect_phase=_SingleForwardPhase.IDLE, next_phase=_SingleForwardPhase.AFTER_PRE_OUT, ...)
bs = int(maybe_inaccurate_forward_batch.batch_size)
num_tokens = int(maybe_inaccurate_forward_batch.positions.shape[0])
if bs > self._write_req_capacity:
raise RuntimeError(f"batch_size {bs} exceeds write_req_capacity {self._write_req_capacity}")
if num_tokens > self._write_entry_capacity:
raise RuntimeError(f"num_tokens {num_tokens} exceeds write_entry_capacity {self._write_entry_capacity}")
def pre_ops_maybe_inside_graph(self, forward_batch: ForwardBatch) -> _PreOpsMaybeInsideGraphOutput:
self._phase_checker.update(expect_phase=_SingleForwardPhase.AFTER_PRE_OUT, next_phase=_SingleForwardPhase.AFTER_PRE_MAYBE_IN, ...)
plan_input = self._plan_input # 预分配缓冲区,从 forward_batch 填充
plan_input.fill_from_forward_batch(forward_batch=forward_batch)
# 调用 JIT kernel 生成验证 / 写计划
invoke_plan(plan_input=plan_input, verify_plan=verify_plan, write_plan=write_plan, group=group, ...)
return _PreOpsMaybeInsideGraphOutput(verify_plans=verify_plans, write_plans=write_plans, expected_inputs=expected_inputs)
python/sglang/srt/kv_canary/runner/canary_manager.py
顶层协调器 CanaryManager,负责生命周期、端点构建和 forward 钩子的调度。
# 文件 : python/sglang/srt/kv_canary/runner/canary_manager.py
class CanaryManager:
def __init__(self, *, config, buffer_groups, device, req_to_token_pool, launch_capacities, swa_window_size=0):
self._device_state = CanaryDeviceState.allocate(config=config, device=device, ...)
self._endpoints = tuple(
endpoint
for group in self._buffer_groups
for endpoint in build_endpoints_from_group(group=group, device_state=self._device_state)
)
self._single_forward_managers = (SingleForwardManager(...),)
@contextlib.contextmanager
def with_active_single_forward_manager(self, index: int) -> Iterator[None]:
# 确保不嵌套
assert self._active_single_forward_manager_index is None, "kv-canary: nested with_active_single_forward_manager is forbidden"
self._active_single_forward_manager_index = index
try:
yield
finally:
assert self._active_single_forward_manager_index == index
self._active_single_forward_manager_index = None
def pre_ops_maybe_inside_graph(self, forward_batch: ForwardBatch) -> _PreOpsMaybeInsideGraphOutput:
sfm = self._single_forward_managers[self._active_single_forward_manager_index]
return sfm.pre_ops_maybe_inside_graph(forward_batch=forward_batch)
def with_ops_outside_graph(self, forward_batch: ForwardBatch, ...):
# 调用 pre_ops_outside_graph,然后 yield,然后 post_ops_outside_graph
self._pre_ops_outside_graph(forward_batch=forward_batch)
try:
yield
finally:
self._post_ops_outside_graph(...)
python/sglang/srt/kv_canary/endpoint.py
定义 CanaryEndpoint,封装单个验证/写对的启动逻辑。
# 文件 : python/sglang/srt/kv_canary/endpoint.py
@dataclass(frozen=True, slots=True, kw_only=True)
class CanaryEndpoint:
kernel_kind: CanaryLaunchTag
canary_buf: torch.Tensor
full_to_swa_index_mapping: Optional[torch.Tensor]
slot_run_counter_view: torch.Tensor
kernel_run_counter_view: torch.Tensor
enable_chain_position_assert: torch.Tensor
def launch_per_forward(self, *, verify_plan, write_plan, input_ids, positions, out_cache_loc, ...):
context = self._make_verify_or_write_context(violation_log=violation_log)
launch_canary_verify_kernel(context=context, plan=verify_plan, check_verify_expected_token=...)
# SWA 端点需要索引映射
if self.full_to_swa_index_mapping is not None:
out_cache_loc_for_canary = self.full_to_swa_index_mapping[out_cache_loc]
else:
out_cache_loc_for_canary = out_cache_loc
launch_canary_write_kernel(context=context, plan=write_plan, input_ids=input_ids, positions=positions,
out_cache_loc=out_cache_loc_for_canary, ...)
评论区精华
Reviewer gemini-code-assist[bot] 提出了 9 条评论,主要集中在:
风险与影响
- 风险:
- 回归风险:全新模块,不影响现有功能;但
forward_batch_info.py 的 envs 导入未使用,需确认无副作用。
- 性能风险:每次 forward 增加额外 kernel 启动和 d2h 拷贝开销,但可通过配置禁用。
- 兼容性:当前仅支持 MHA 池,SWA 和 DeepSeek-V4 适配后续添加,非 MHA 模型无法使用。
- 稳定性:RAISE 模式下违规直接抛出
RuntimeError,可能导致服务中断;LOG 模式更安全。
- 影响:
- 用户影响:默认不启用,对普通用户无影响;启用时附加开销,但助于诊断 KV-cache 问题。
- 系统影响:增加 ~3.7k 行代码(含测试),构建时间略有增加。
- 团队影响:提供统一的 KV-cache 验证框架,便于后续扩展和测试。
- 风险标记:核心新模块集成风险, 启用后性能开销, 仅支持MHA池, assert在-O下失效
关联脉络
- PR #26809 Add the KV-canary install API and forward-path wiring: 在前向路径中安装 KV-canary 的 API,依赖当前 PR 的核心组件。
- PR #26810 Add KV-canary SWA + DeepSeek-V4 pool support: 扩展池支持 SWA 和 DeepSeek-V4,基于当前 PR 的核心架构。
- PR #26819 Add the KV-canary perturb modes and PD-disaggregation e2e tests: 添加扰动模式和 PD 拆分端到端测试,进一步验证核心功能。
- PR #26821 Add periodic KV-canary stats logging and kernel-run-counter health check: 增强可观测性,依赖本 PR 的运行时状态。
参与讨论