执行摘要
- 一句话:修复 V1 speculative decoding 中 draft_probs 未传递使 probabilistic rejection 失效
- 推荐动作:值得精读。本 PR 虽然改动量中等,但修复了一个重要的正确性问题,展示了 speculative decoding 中 draft_probs 的完整生命周期:从 proposer 采样时捕获,跨模块缓存,到 GPUModelRunner 按请求重新排列,最终传递给 rejection sampler。设计模式清晰,配套测试完善。尤其推荐关注 _get_spec_decode_draft_probs 中的请求顺序对齐逻辑。
功能与动机
关联 Issue #40149 指出,在 V1 中使用 draft_model 与 probabilistic 拒绝采样时,GPUModelRunner._sample() 传递给 RejectionSampler 的 draft_probs 参数始终为 None,导致 draft_prob 默认化为 1,偏离了 Leviathan et al. (2022) 定义的分布匹配逻辑(p(x)/q(x) 比率)。本 PR 旨在将 draft 模型的概率分布实际传入 rejection 流程,以恢复正确的 probabilistic 拒绝采样行为。
实现拆解
- Proposer 侧采样与概率捕获:在
vllm/v1/spec_decode/llm_base_proposer.py 中新增 _sample_from_logits 和 _sample_draft_tokens 方法。当配置为 rejection_sample_method="standard" 且 draft_sample_method="probabilistic" 时,使用 compute_probs_and_sample_next_token 进行概率采样并返回分布;否则回退到贪婪采样。每次 propose() 调用末尾将 _last_draft_probs 缓存起来。
- Runner 侧缓存与重新排列:在
vllm/v1/worker/gpu_model_runner.py 中新增 _draft_probs 和 _draft_prob_req_ids 属性,在 propose_draft_token_ids() 中从 proposer 的 take_last_draft_probs() 获取概率并关联当前请求 ID;在 sample_tokens() 开始时重置。
- 按请求顺序提取:
_get_spec_decode_draft_probs() 方法根据当前 input_batch.req_ids 和 num_draft_tokens,从缓存中按请求顺序提取对应的概率行和切片,处理缺失请求和零 draft 数量。
- 传递有效概率:
_sample() 中调用 _get_spec_decode_draft_probs() 并将返回值传入 rejection_sampler,替换原来的 None。
- 配置重命名:
vllm/config/speculative.py 中将 DraftSampleMethod 从 'greedy' | 'gumbel' 改为 'greedy' | 'probabilistic',并更新文档字符串。
- 测试配套:新增三个测试用例:
test_propose_stores_probabilistic_draft_probs(验证 proposer 缓存概率)、test_sample_passes_reordered_draft_probs_to_rejection_sampler(验证 runner 重新排列并传递)、以及配置校验测试。
关键文件:
vllm/v1/spec_decode/llm_base_proposer.py(模块 推测解码提议器;类别 source;类型 core-logic;符号 _sample_from_logits, _sample_draft_tokens, take_last_draft_probs): 核心提议器,新增 _sample_from_logits 和 _sample_draft_tokens 以在 probabilistic 模式下捕获 draft probabilities,并新增 take_last_draft_probs 供 GPUModelRunner 获取。
vllm/v1/worker/gpu_model_runner.py(模块 模型运行器;类别 source;类型 data-contract;符号 _get_spec_decode_draft_probs): 模型运行器,新增缓存 _draft_probs 和 _draft_prob_req_ids,新增 _get_spec_decode_draft_probs 方法按请求顺序重新排列 draft 概率,并替换以前传递的 None。
tests/v1/spec_decode/test_eagle.py(模块 Eagle 测试;类别 test;类型 test-coverage;符号 test_propose_stores_probabilistic_draft_probs, fake_compute_probs): 新增 test_propose_stores_probabilistic_draft_probs 验证 proposer 在 probabilistic 模式下正确缓存概率,并扩展了 _create_proposer 以支持新的配置参数。
tests/v1/worker/test_gpu_model_runner.py(模块 模型运行器测试;类别 test;类型 test-coverage;符号 test_sample_passes_reordered_draft_probs_to_rejection_sampler): 新增 test_sample_passes_reordered_draft_probs_to_rejection_sampler 验证 GPUModelRunner 正确重新排序 draft_probs 并传递给 rejection sampler。
tests/test_config.py(模块 配置测试;类别 test;类型 test-coverage;符号 test_draft_sample_method_probabilistic_is_accepted, test_draft_sample_method_gumbel_is_rejected): 新增 test_draft_sample_method_probabilistic_is_accepted 和 test_draft_sample_method_gumbel_is_rejected 测试验证配置接受 'probabilistic' 并拒绝 'gumbel'。
vllm/config/speculative.py(模块 推测配置;类别 source;类型 core-logic;符号 DraftSampleMethod): 核心配置变更,将 DraftSampleMethod 从 'greedy' | 'gumbel' 改为 'greedy' | 'probabilistic',并更新文档字符串。
关键符号:_sample_from_logits, _sample_draft_tokens, take_last_draft_probs, _get_spec_decode_draft_probs
关键源码片段
vllm/v1/spec_decode/llm_base_proposer.py
核心提议器,新增 _sample_from_logits 和 _sample_draft_tokens 以在 probabilistic 模式下捕获 draft probabilities,并新增 take_last_draft_probs 供 GPUModelRunner 获取。
# llm_base_proposer.py ( 新增字段和方法 )
# 在 __init__ 中启用条件
self._enable_probabilistic_draft_probs = (
self.speculative_config.rejection_sample_method == "standard"
and self.speculative_config.draft_sample_method == "probabilistic"
)
self._last_draft_probs: torch.Tensor | None = None
def _sample_from_logits(self, logits, sampling_metadata):
"""根据配置和采样元数据决定是否使用概率采样。
Returns: (采样 token ids, 概率分布或 None)
"""
if not self._enable_probabilistic_draft_probs:
return logits.argmax(dim=-1), None # 贪婪模式
if sampling_metadata.all_greedy:
return logits.argmax(dim=-1), None # 全部贪婪则无需概率
# 使用 compute_probs_and_sample_next_token 进行概率采样
return compute_probs_and_sample_next_token(logits, sampling_metadata)
def _sample_draft_tokens(self, hidden_states, sampling_metadata):
"""对 draft 模型隐藏状态进行采样。
在 probabilistic 模式下计算 logits 后调用 _sample_from_logits。
"""
if not self._enable_probabilistic_draft_probs or sampling_metadata.all_greedy:
return self._greedy_sample(hidden_states), None
logits = self.model.compute_logits(hidden_states)
return self._sample_from_logits(logits, sampling_metadata)
def take_last_draft_probs(self):
return self._last_draft_probs
# 在 propose() 中每次调用前重置
def propose(self, ...) -> torch.Tensor:
self._last_draft_probs = None
...
# 调用 _sample_draft_tokens 获得 draft_probs
draft_token_ids, draft_probs = self._sample_draft_tokens(sample_hidden_states, sampling_metadata)
if draft_probs is not None:
self._last_draft_probs = draft_probs.view(-1, self.num_speculative_tokens, draft_probs.shape[-1]).contiguous()
return draft_token_ids.view(-1, self.num_speculative_tokens)
vllm/v1/worker/gpu_model_runner.py
模型运行器,新增缓存 _draft_probs 和 _draft_prob_req_ids,新增 _get_spec_decode_draft_probs 方法按请求顺序重新排列 draft 概率,并替换以前传递的 None。
# gpu_model_runner.py ( 新增缓存和提取方法 )
# 在 __init__ 中新增属性
self._draft_probs: torch.Tensor | None = None
self._draft_prob_req_ids: list[str] | None = None
# 在 propose_draft_token_ids 和 sample_tokens 的开始重置
def propose_draft_token_ids(self, scheduler_output):
self._draft_probs = None
self._draft_prob_req_ids = None
# ... 原有逻辑,在获得 draft_probs 后缓存
draft_probs = proposer.take_last_draft_probs()
if draft_probs is not None:
self._draft_probs = draft_probs
self._draft_prob_req_ids = list(self.input_batch.req_ids)
# 新方法:将缓存的 draft_probs 按当前批次的请求顺序和 draft 数量重新排列
def _get_spec_decode_draft_probs(self, spec_decode_metadata):
if self._draft_probs is None or self._draft_prob_req_ids is None:
return None
# 构建 请求 ID -> 缓存行索引 的映射
row_by_req_id = {req_id: idx for idx, req_id in enumerate(self._draft_prob_req_ids)}
draft_probs_rows = []
# 遍历当前批次的请求和对应的 draft 数量
for req_id, num_draft in zip(self.input_batch.req_ids, spec_decode_metadata.num_draft_tokens):
if num_draft == 0:
continue # 跳过无 draft 的请求,不影响排列顺序
row_idx = row_by_req_id.get(req_id)
if row_idx is None:
# 若请求未在缓存中找到(如新请求),回退到 None
logger.warning("Missing cached draft probabilities for request %s; "
"falling back to legacy speculative rejection behavior.", req_id)
return None
draft_probs_rows.append(self._draft_probs[row_idx, :num_draft])
if not draft_probs_rows:
return None
return torch.cat(draft_probs_rows, dim=0).contiguous()
# 在 _sample 中使用
def _sample(self, logits, spec_decode_metadata):
...
draft_probs = self._get_spec_decode_draft_probs(spec_decode_metadata)
sampler_output = self.rejection_sampler(
spec_decode_metadata,
draft_probs, # 之前是 None
logits,
sampling_metadata,
)
return sampler_output
评论区精华
主要讨论包括:
- benchislett 建议将 draft_sample_method 从 'gumbel' 重命名为 'probabilistic',因为 V1 使用的是常规概率采样而非 Gumbel 噪声,bedeks 采纳并重命名。
- benchislett 对 _get_spec_decode_draft_probs 中 num_draft == 0 的情况处理提出疑问,bedeks 解释跳过零长度的条目不会打乱顺序,且有回归测试覆盖(请求 [a,b,c] 分别对应 draft 数量 [2,0,1] 时,输出应为 [a[:2], c[:1]])。
- benchislett 指出 draft_probs_list 初始化应为 None 而非 [],bedeks 修正。
- draft_sample_method 命名调整 (design): bedeks 采纳并重命名为 probabilistic,同时更新文档字符串。
- 处理请求无 draft token 的边界情况 (correctness): bedeks 解释跳过零长度条目不会影响顺序,且有回归测试验证请求顺序 [a,b,c] 对应 draft 数量 [2,0,1] 时输出正确。
- draft_probs_list 初始化值 (correctness): bedeks 确认并修正为 None。
风险与影响
- 风险:
- 仅当 rejection_sample_method='standard' 且 draft_sample_method='probabilistic' 时才启用新路径,默认行为不受影响。
- 新增的 _last_draft_probs 等缓存若不及时清理,可能导致内存占用缓慢增长;代码在 propose_draft_token_ids 和 sample_tokens 中都设置了显式重置为 None。
- 请求顺序重新映射逻辑假定 input_batch.req_ids 与 spec_decode_metadata.num_draft_tokens 顺序一致,若不符可能导致概率错位;测试验证了正确性。
- 若请求缺失缓存概率(如新的请求进入但概率尚未填充),代码会回退到 None(即 fallback 行为),不会崩溃,但可能短暂回归旧行为。
- 影响:影响范围限于 V1 speculative decoding 中采用 probabilistic rejection 的用户。该修复使 acceptance rate 从约 0.22 提升至 0.45,显著提高推理性能。对使用默认 greedy 采样的用户无影响。新增测试覆盖确保了核心路径的可靠性。配置变更要求用户将 'gumbel' 改为 'probabilistic',但之前该设置仅适用于 MRV2,V1 本应使用 'probabilistic',因此这是预期修复。
- 风险标记:核心路径变更, 数据流顺序依赖, 缓存生命周期
关联脉络
- PR #40149 [Feature]: Speculative Decoding using draft_model does not use draft_probs: 此 issue 报告了 draft_probs 未传递的问题,是本次 PR 的直接动机。
参与讨论