Prhub

#26808 Add the KV-canary core: data layer, MHA KV-pool patcher, and per-forward runner

原始 PR 作者 fzyzcjy 合并时间 2026-05-31 09:54 文件变更 46 提交数 1 评论 9 代码增减 +3756 / -0

执行摘要

添加 KV-canary 核心:数据层、KV 池修补器和前向运行器

PR 描述指出这是一个自包含的金丝雀核心,后续将分 PR 添加更多池支持和 E2E 测试。动机是为 sglang 提供一种可插拔的运行时验证机制,用于检测 KV-cache 中的数据损坏。

此 PR 是 KV-canary 系列的基础,建议精读理解设计模式:状态管理、池修补、前向钩子集成。关注 assert 替换为 Exception 的讨论,这是生产代码的重要稳健性考量。

讨论亮点

Reviewer gemini-code-assist[bot] 提出了 9 条评论,主要集中在:

  • assert 替换为 RuntimeError:多处 assert 用于运行时验证(如嵌套上下文检查、状态一致性),建议改为显式异常以避免被 -O 优化绕过。
  • 异常安全性FutureTensors.step 中若 postprocess_on_host 抛出异常,_future 不会清空,建议使用 try...finally 确保清理。
  • 空输入早返回launch_endpoints_per_forward 中若 positions 为空,应提前返回避免内核启动。
  • 未使用导入canary_manager.pyforward_batch_info.py 中存在未使用的 envs 导入,建议移除。
    这些评论均未得到作者明确回复,但 PR 已合并。

实现拆解

  1. 数据层与状态管理:新增 configstatebuffer_groupcapacities 等文件,定义金丝雀的配置和运行时状态,包括 CanaryDeviceStateViolationLog
  2. KV-pool 修补器:在 pool_patcher/buf_info_splice.py 中修补 KV-pool 的缓冲区信息方法,插入金丝雀缓冲区视图。
  3. 端点与运行器:实现 CanaryEndpoint(代表一个验证-写对)和 SingleForwardManager(管理单次 forward 的验证和写入计划),以及顶层的 CanaryManager 协调多个 forward 管理器。
  4. 验证与违规报告ViolationReporterViolationManager 负责从 GPU 读取违规环形缓冲区并在违规时记录或抛出异常。ViolationReporter 支持 LOGRAISE 两种模式。
  5. 集成与测试:修改 forward_batch_info.py 以导入 envs(可能用于环境变量),新增 18+ 个测试文件覆盖单元和集成测试,验证金丝雀管线的正确性。
文件 模块 状态 重要度
python/sglang/srt/kv_canary/single_forward_manager/manager.py 前向管理器 added 9.24
python/sglang/srt/kv_canary/runner/canary_manager.py 金丝雀运行器 added 9.13
python/sglang/srt/kv_canary/endpoint.py 端点 added 9.04
python/sglang/srt/kv_canary/runner/violation_reporter.py 违规报告器 added 8.83
python/sglang/srt/kv_canary/state.py 状态管理 added 8.43
python/sglang/srt/kv_canary/capacities.py 容量配置 added 8.25

关键符号

SingleForwardManager.__init__ SingleForwardManager.pre_ops_outside_graph SingleForwardManager.pre_ops_maybe_inside_graph CanaryManager.__init__ CanaryManager.with_active_single_forward_manager CanaryManager.pre_ops_maybe_inside_graph CanaryEndpoint.launch_per_forward CanaryEndpoint._make_verify_or_write_context build_endpoints_from_group ViolationReporter.log_or_raise_violation ViolationReporter._format_violation CanaryDeviceState.allocate CanaryLaunchCapacities.from_args PlanInput.allocate PlanInput.fill_from_forward_batch launch_endpoints_per_forward invoke_plan

关键源码片段

python/sglang/srt/kv_canary/single_forward_manager/manager.py core-logic

定义了 SingleForwardManager 管理单次 forward 的验证 / 写入计划,是核心运行器之一。

# 文件 : python/sglang/srt/kv_canary/single_forward_manager/manager.py
class _SingleForwardPhase(IntEnum):
    IDLE = 0
    AFTER_PRE_OUT = 1
    AFTER_PRE_MAYBE_IN = 2
    AFTER_POST_MAYBE_IN = 3class SingleForwardManager:
    def __init__(self, ..., d2h_stream: torch.cuda.Stream):
        self._phase_checker = SimplePhaseChecker(initial_phase=_SingleForwardPhase.IDLE, device=device)
        self._output_buffer = PostOpsInsideGraphOutputBuffer.allocate(...)
​
    def pre_ops_outside_graph(self, *, maybe_inaccurate_forward_batch: ForwardBatch) -> None:
        self._phase_checker.update(expect_phase=_SingleForwardPhase.IDLE, next_phase=_SingleForwardPhase.AFTER_PRE_OUT, ...)
        bs = int(maybe_inaccurate_forward_batch.batch_size)
        num_tokens = int(maybe_inaccurate_forward_batch.positions.shape[0])
        if bs > self._write_req_capacity:
            raise RuntimeError(f"batch_size {bs} exceeds write_req_capacity {self._write_req_capacity}")
        if num_tokens > self._write_entry_capacity:
            raise RuntimeError(f"num_tokens {num_tokens} exceeds write_entry_capacity {self._write_entry_capacity}")
​
    def pre_ops_maybe_inside_graph(self, forward_batch: ForwardBatch) -> _PreOpsMaybeInsideGraphOutput:
        self._phase_checker.update(expect_phase=_SingleForwardPhase.AFTER_PRE_OUT, next_phase=_SingleForwardPhase.AFTER_PRE_MAYBE_IN, ...)
        plan_input = self._plan_input # 预分配缓冲区,从 forward_batch 填充
        plan_input.fill_from_forward_batch(forward_batch=forward_batch)
        # 调用 JIT kernel 生成验证 / 写计划
        invoke_plan(plan_input=plan_input, verify_plan=verify_plan, write_plan=write_plan, group=group, ...)
        return _PreOpsMaybeInsideGraphOutput(verify_plans=verify_plans, write_plans=write_plans, expected_inputs=expected_inputs)
python/sglang/srt/kv_canary/runner/canary_manager.py core-logic

顶层协调器 CanaryManager,负责生命周期、端点构建和 forward 钩子的调度。

# 文件 : python/sglang/srt/kv_canary/runner/canary_manager.py
class CanaryManager:
    def __init__(self, *, config, buffer_groups, device, req_to_token_pool, launch_capacities, swa_window_size=0):
        self._device_state = CanaryDeviceState.allocate(config=config, device=device, ...)
        self._endpoints = tuple(
            endpoint
            for group in self._buffer_groups
            for endpoint in build_endpoints_from_group(group=group, device_state=self._device_state)
        )
        self._single_forward_managers = (SingleForwardManager(...),)
​
    @contextlib.contextmanager
    def with_active_single_forward_manager(self, index: int) -> Iterator[None]:
        # 确保不嵌套
        assert self._active_single_forward_manager_index is None, "kv-canary: nested with_active_single_forward_manager is forbidden"
        self._active_single_forward_manager_index = index
        try:
            yield
        finally:
            assert self._active_single_forward_manager_index == index
            self._active_single_forward_manager_index = None
​
    def pre_ops_maybe_inside_graph(self, forward_batch: ForwardBatch) -> _PreOpsMaybeInsideGraphOutput:
        sfm = self._single_forward_managers[self._active_single_forward_manager_index]
        return sfm.pre_ops_maybe_inside_graph(forward_batch=forward_batch)
​
    def with_ops_outside_graph(self, forward_batch: ForwardBatch, ...):
        # 调用 pre_ops_outside_graph,然后 yield,然后 post_ops_outside_graph
        self._pre_ops_outside_graph(forward_batch=forward_batch)
        try:
            yield
        finally:
            self._post_ops_outside_graph(...)
python/sglang/srt/kv_canary/endpoint.py core-logic

定义 CanaryEndpoint,封装单个验证 / 写对的启动逻辑。

# 文件 : python/sglang/srt/kv_canary/endpoint.py
@dataclass(frozen=True, slots=True, kw_only=True)
class CanaryEndpoint:
    kernel_kind: CanaryLaunchTag
    canary_buf: torch.Tensor
    full_to_swa_index_mapping: Optional[torch.Tensor]
    slot_run_counter_view: torch.Tensor
    kernel_run_counter_view: torch.Tensor
    enable_chain_position_assert: torch.Tensor
​
    def launch_per_forward(self, *, verify_plan, write_plan, input_ids, positions, out_cache_loc, ...):
        context = self._make_verify_or_write_context(violation_log=violation_log)
        launch_canary_verify_kernel(context=context, plan=verify_plan, check_verify_expected_token=...)
        # SWA 端点需要索引映射
        if self.full_to_swa_index_mapping is not None:
            out_cache_loc_for_canary = self.full_to_swa_index_mapping[out_cache_loc]
        else:
            out_cache_loc_for_canary = out_cache_loc
        launch_canary_write_kernel(context=context, plan=write_plan, input_ids=input_ids, positions=positions,
                                   out_cache_loc=out_cache_loc_for_canary, ...)

评论区精华

assert 应替换为显式 RuntimeError 正确性

reviewer 指出 canary_manager.py 中多处 assert(如 nested 检查、退出时索引检查)在 Python -O 模式下会被跳过,导致状态损坏。建议改为 if raise RuntimeError。

结论:评论未得到作者回应,PR 已合并但代码仍保留 assert。 · unresolved

FutureTensors 异常安全 正确性

reviewer 指出 future_tensor.py 的 step 方法中,若 postprocess_on_host 抛出异常,self._future 不会被清空,建议用 try...finally 确保即使异常也清除。

结论:评论未回应,代码未修改。 · unresolved

空输入早期返回 性能

reviewer 建议在 kernel_launcher.py 的 launch_endpoints_per_forward 中,如果 num_tokens == 0,提前返回以避免不必要的内核启动和潜在的空张量错误。

结论:未修改,但可视为优化建议。 · unresolved

未使用的 import envs style

reviewer 指出 canary_manager.py 和 forward_batch_info.py 导入了 envs 但未使用,建议移除。

结论:未被采纳,导入保留。 · unresolved

风险与影响

  • 回归风险:全新模块,不影响现有功能;但 forward_batch_info.pyenvs 导入未使用,需确认无副作用。
  • 性能风险:每次 forward 增加额外 kernel 启动和 d2h 拷贝开销,但可通过配置禁用。
  • 兼容性:当前仅支持 MHA 池,SWA 和 DeepSeek-V4 适配后续添加,非 MHA 模型无法使用。
  • 稳定性:RAISE 模式下违规直接抛出 RuntimeError,可能导致服务中断;LOG 模式更安全。
  • 用户影响:默认不启用,对普通用户无影响;启用时附加开销,但助于诊断 KV-cache 问题。
  • 系统影响:增加 ~3.7k 行代码(含测试),构建时间略有增加。
  • 团队影响:提供统一的 KV-cache 验证框架,便于后续扩展和测试。
核心新模块集成风险 启用后性能开销 仅支持 MHA 池 assert 在 -O 下失效

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论