执行摘要
- 一句话:跳过草稿预填充前的注意力元数据重建
- 推荐动作:该PR值得精读,尤其是对v1推测解码架构和CUDA图捕获流程感兴趣的开发者。
PrefillEagleCudaGraphManager与DecodeEagleCudaGraphManager的拆分设计可复用。由于缺少测试覆盖和潜在的签名不匹配风险,建议合入前补充至少一个端到端测试用例验证不同推测配置。
功能与动机
在MRV2中,草稿预填充回放FULL CUDA图时必须重建注意力元数据,因为捕获时使用的builder状态与回放时的builder不匹配,可能导致crash(见FlashAttention后端中的scheduler metadata buffer问题)。实际上,可以复用目标模型捕获时生成的同一组注意力元数据,从而完全跳过重建步骤。
实现拆解
- 新增CapturedAttentionState类型(
vllm/v1/worker/gpu/cudagraph_utils.py):定义NamedTuple,打包attn_metadata和slot_mappings,作为捕获状态的标准传递单元。
- 改造CudaGraphManager.capture返回值(
cudagraph_utils.py):原方法返回None,现返回dict[BatchExecutionDescriptor, CapturedAttentionState],使得调用方(即目标模型的runner)能够捕获并传递注意力状态。
- 拆分EagleCudaGraphManager(
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py):原单一类拆分为EagleCudaGraphManagerBase(仅保留独立graph pool的公用初始化)、PrefillEagleCudaGraphManager(接受外部传入的注意力状态,用于草稿预填充)和DecodeEagleCudaGraphManager(自行调用prepare_inputs_to_capture构建注意力状态,用于草稿解码)。
- 修改EagleSpeculator(
speculator.py):将capture_model方法改为capture(attn_states),接收来自目标模型runner的注意力状态字典,并传递给PrefillEagleCudaGraphManager。同时移除原propose中重建注意力元数据的冗余逻辑。
- 连通GPUModelRunner(
model_runner.py):在capture_model中捕获目标模型的注意力状态后,直接调用self.speculator.capture(captured_attn_states),完成状态传递。
该PR不涉及测试、配置或部署配套变更。
关键文件:
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py(模块 CUDA图;类别 source;类型 core-logic;符号 EagleCudaGraphManager, EagleCudaGraphManagerBase, PrefillEagleCudaGraphManager, capture): 核心重构文件:将EagleCudaGraphManager拆分为PrefillEagleCudaGraphManager和DecodeEagleCudaGraphManager,实现复用vs自建注意力的两种模式
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py(模块 推测解码;类别 source;类型 core-logic;符号 capture_model, capture): 修改了EagleSpeculator的capture方法签名,接收注意力状态并消除propose中的重建逻辑
vllm/v1/worker/gpu/cudagraph_utils.py(模块 CUDA图;类别 source;类型 core-logic;符号 CapturedAttentionState): 引入CapturedAttentionState类型并修改CudaGraphManager.capture的返回值接口
vllm/v1/worker/gpu/model_runner.py(模块 模型运行器;类别 source;类型 data-contract): 将捕获到的注意力状态传递给speculator,是数据流连通的关键一环
关键符号:EagleCudaGraphManagerBase.capture, PrefillEagleCudaGraphManager.capture, DecodeEagleCudaGraphManager.capture, EagleSpeculator.capture, GPUModelRunner.capture_model
关键源码片段
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
核心重构文件:将EagleCudaGraphManager拆分为PrefillEagleCudaGraphManager和DecodeEagleCudaGraphManager,实现复用vs自建注意力的两种模式
class PrefillEagleCudaGraphManager(EagleCudaGraphManagerBase):
"""Eagle CudaGraphManager for prefill,使用目标模型捕获时预先构建的注意力状态"""
def capture(
self,
forward_fn: Callable,
full_cg_attn_states: dict[BatchExecutionDescriptor, CapturedAttentionState],
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
# 根据描述符获取已由目标模型捕获的注意力状态,避免重新构建
def create_forward_fn(
desc: BatchExecutionDescriptor,
) -> tuple[Callable[[CUDAGraphMode], None], CapturedAttentionState]:
num_tokens = desc.num_tokens
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
num_tokens_across_dp = (
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
if self.dp_size > 1
else None
)
# 直接使用传进来的注意力状态,不调用 prepare_inputs_to_capture
attn_state = full_cg_attn_states[desc]
attn_metadata, slot_mappings = attn_state
fwd = lambda cg_mode: forward_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
cg_mode,
)
return fwd, attn_state
super().capture(create_forward_fn, progress_bar_desc)
class DecodeEagleCudaGraphManager(EagleCudaGraphManagerBase):
"""Eagle CudaGraphManager for decode draft generation,自己构建注意力元数据"""
def capture(
self,
forward_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
# 与传统流程一致,调用 prepare_inputs_to_capture 构建自己的注意力状态
def create_forward_fn(
desc: BatchExecutionDescriptor,
) -> tuple[Callable[[CUDAGraphMode], None], CapturedAttentionState]:
num_tokens = desc.num_tokens
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
num_tokens_across_dp = (
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
if self.dp_size > 1
else None
)
attn_state = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
kv_cache_config,
)
attn_metadata, slot_mappings = attn_state
fwd = lambda cg_mode: forward_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
cg_mode,
)
return fwd, attn_state
super().capture(create_forward_fn, progress_bar_desc)
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
修改了EagleSpeculator的capture方法签名,接收注意力状态并消除propose中的重建逻辑
```python
def capture(
self,
attn_states: dict[BatchExecutionDescriptor, CapturedAttentionState],
) -> None:
"""
捕获草稿模型的CUDA图。
接收目标模型 runner 传来的注意力状态字典,
传递给预填充管理器,使其能复用目标模型的注意力元数据。
"""
logger.info("Capturing model for Eagle speculator...")
# 重置索引避免 dummy run 中的过期值导致越界
self.num_sched_tokens.fill_(0)
self.num_computed_tokens.fill_(0)
self.num_seqs.fill_(0)
# 预填充管理器使用接收到的注意力状态(来自目标模型捕获)
assert self.prefill_cudagraph_manager is not None
self.prefill_cudagraph_manager.capture(
self.prefill,
attn_states,
progress_bar_desc="Capturing eagle prefill CUDA graphs",
)
# 解码管理器仍自行构建注意力状态(需要自己的模型状态、block 表等)
assert self.decode_cudagraph_manager is not None
self.decode_cudagraph_manager.capture(
self.decode,
self.model_state,
self.input_buffers,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
progress_bar_desc="Capturing eagle decode CUDA graphs",
)
``` (原propose中删除的rebuil逻辑不再展示)
评论区精华
仅有gemini-code-assist的机器人评论总结了变更要点,以及WoosukKwon的LGTM批准,未发现实质性的设计争议或未解决疑虑。
风险与影响
关联脉络
参与讨论