Prhub

#40269 [Bugfix][Spec Decode] Wire draft_probs into probabilistic draft_model rejection

原始 PR 作者 bedeks 合并时间 2026-05-14 09:04 文件变更 7 提交数 6 评论 13 代码增减 +264 / -13

执行摘要

修复 V1 speculative decoding 中 draft_probs 未传递使 probabilistic rejection 失效

关联 Issue #40149 指出,在 V1 中使用 draft_model 与 probabilistic 拒绝采样时,GPUModelRunner._sample() 传递给 RejectionSampler 的 draft_probs 参数始终为 None,导致 draft_prob 默认化为 1,偏离了 Leviathan et al. (2022) 定义的分布匹配逻辑(p(x)/q(x) 比率)。本 PR 旨在将 draft 模型的概率分布实际传入 rejection 流程,以恢复正确的 probabilistic 拒绝采样行为。

值得精读。本 PR 虽然改动量中等,但修复了一个重要的正确性问题,展示了 speculative decoding 中 draft_probs 的完整生命周期:从 proposer 采样时捕获,跨模块缓存,到 GPUModelRunner 按请求重新排列,最终传递给 rejection sampler。设计模式清晰,配套测试完善。尤其推荐关注 _get_spec_decode_draft_probs 中的请求顺序对齐逻辑。

讨论亮点

主要讨论包括:

  1. benchislett 建议将 draft_sample_method 从 'gumbel' 重命名为 'probabilistic',因为 V1 使用的是常规概率采样而非 Gumbel 噪声,bedeks 采纳并重命名。
  2. benchislett 对 _get_spec_decode_draft_probs 中 num_draft == 0 的情况处理提出疑问,bedeks 解释跳过零长度的条目不会打乱顺序,且有回归测试覆盖(请求 [a,b,c] 分别对应 draft 数量 [2,0,1] 时,输出应为 [a[:2], c[:1]])。
  3. benchislett 指出 draft_probs_list 初始化应为 None 而非 [],bedeks 修正。

实现拆解

  1. Proposer 侧采样与概率捕获:在 vllm/v1/spec_decode/llm_base_proposer.py 中新增 _sample_from_logits_sample_draft_tokens 方法。当配置为 rejection_sample_method="standard"draft_sample_method="probabilistic" 时,使用 compute_probs_and_sample_next_token 进行概率采样并返回分布;否则回退到贪婪采样。每次 propose() 调用末尾将 _last_draft_probs 缓存起来。
  2. Runner 侧缓存与重新排列:在 vllm/v1/worker/gpu_model_runner.py 中新增 _draft_probs_draft_prob_req_ids 属性,在 propose_draft_token_ids() 中从 proposer 的 take_last_draft_probs() 获取概率并关联当前请求 ID;在 sample_tokens() 开始时重置。
  3. 按请求顺序提取_get_spec_decode_draft_probs() 方法根据当前 input_batch.req_idsnum_draft_tokens,从缓存中按请求顺序提取对应的概率行和切片,处理缺失请求和零 draft 数量。
  4. 传递有效概率_sample() 中调用 _get_spec_decode_draft_probs() 并将返回值传入 rejection_sampler,替换原来的 None
  5. 配置重命名vllm/config/speculative.py 中将 DraftSampleMethod'greedy' | 'gumbel' 改为 'greedy' | 'probabilistic',并更新文档字符串。
  6. 测试配套:新增三个测试用例:test_propose_stores_probabilistic_draft_probs(验证 proposer 缓存概率)、test_sample_passes_reordered_draft_probs_to_rejection_sampler(验证 runner 重新排列并传递)、以及配置校验测试。
文件 模块 状态 重要度
vllm/v1/spec_decode/llm_base_proposer.py 推测解码提议器 modified 7.88
vllm/v1/worker/gpu_model_runner.py 模型运行器 modified 7.39
tests/v1/spec_decode/test_eagle.py Eagle 测试 modified 6.47
tests/v1/worker/test_gpu_model_runner.py 模型运行器测试 modified 5.69
tests/test_config.py 配置测试 modified 5.09
vllm/config/speculative.py 推测配置 modified 5.34

关键符号

_sample_from_logits _sample_draft_tokens take_last_draft_probs _get_spec_decode_draft_probs

关键源码片段

vllm/v1/spec_decode/llm_base_proposer.py core-logic

核心提议器,新增 _sample_from_logits 和 _sample_draft_tokens 以在 probabilistic 模式下捕获 draft probabilities,并新增 take_last_draft_probs 供 GPUModelRunner 获取。

# llm_base_proposer.py ( 新增字段和方法 )
# 在 __init__ 中启用条件
self._enable_probabilistic_draft_probs = (
    self.speculative_config.rejection_sample_method == "standard"
    and self.speculative_config.draft_sample_method == "probabilistic"
)
self._last_draft_probs: torch.Tensor | None = Nonedef _sample_from_logits(self, logits, sampling_metadata):
    """根据配置和采样元数据决定是否使用概率采样。
    Returns: (采样 token ids, 概率分布或 None)
    """
    if not self._enable_probabilistic_draft_probs:
        return logits.argmax(dim=-1), None # 贪婪模式
    if sampling_metadata.all_greedy:
        return logits.argmax(dim=-1), None # 全部贪婪则无需概率
    # 使用 compute_probs_and_sample_next_token 进行概率采样
    return compute_probs_and_sample_next_token(logits, sampling_metadata)def _sample_draft_tokens(self, hidden_states, sampling_metadata):
    """对 draft 模型隐藏状态进行采样。
    在 probabilistic 模式下计算 logits 后调用 _sample_from_logits。
    """
    if not self._enable_probabilistic_draft_probs or sampling_metadata.all_greedy:
        return self._greedy_sample(hidden_states), None
    logits = self.model.compute_logits(hidden_states)
    return self._sample_from_logits(logits, sampling_metadata)def take_last_draft_probs(self):
    return self._last_draft_probs# 在 propose() 中每次调用前重置
def propose(self, ...) -> torch.Tensor:
    self._last_draft_probs = None
    ...
    # 调用 _sample_draft_tokens 获得 draft_probs
    draft_token_ids, draft_probs = self._sample_draft_tokens(sample_hidden_states, sampling_metadata)
    if draft_probs is not None:
        self._last_draft_probs = draft_probs.view(-1, self.num_speculative_tokens, draft_probs.shape[-1]).contiguous()
    return draft_token_ids.view(-1, self.num_speculative_tokens)
vllm/v1/worker/gpu_model_runner.py data-contract

模型运行器,新增缓存 _draft_probs 和 _draft_prob_req_ids,新增 _get_spec_decode_draft_probs 方法按请求顺序重新排列 draft 概率,并替换以前传递的 None。

# gpu_model_runner.py ( 新增缓存和提取方法 )
# 在 __init__ 中新增属性
self._draft_probs: torch.Tensor | None = None
self._draft_prob_req_ids: list[str] | None = None# 在 propose_draft_token_ids 和 sample_tokens 的开始重置
def propose_draft_token_ids(self, scheduler_output):
    self._draft_probs = None
    self._draft_prob_req_ids = None
    # ... 原有逻辑,在获得 draft_probs 后缓存
    draft_probs = proposer.take_last_draft_probs()
    if draft_probs is not None:
        self._draft_probs = draft_probs
        self._draft_prob_req_ids = list(self.input_batch.req_ids)# 新方法:将缓存的 draft_probs 按当前批次的请求顺序和 draft 数量重新排列
def _get_spec_decode_draft_probs(self, spec_decode_metadata):
    if self._draft_probs is None or self._draft_prob_req_ids is None:
        return None
    # 构建 请求 ID -> 缓存行索引 的映射
    row_by_req_id = {req_id: idx for idx, req_id in enumerate(self._draft_prob_req_ids)}
    draft_probs_rows = []
    # 遍历当前批次的请求和对应的 draft 数量
    for req_id, num_draft in zip(self.input_batch.req_ids, spec_decode_metadata.num_draft_tokens):
        if num_draft == 0:
            continue # 跳过无 draft 的请求,不影响排列顺序
        row_idx = row_by_req_id.get(req_id)
        if row_idx is None:
            # 若请求未在缓存中找到(如新请求),回退到 None
            logger.warning("Missing cached draft probabilities for request %s; "
                           "falling back to legacy speculative rejection behavior.", req_id)
            return None
        draft_probs_rows.append(self._draft_probs[row_idx, :num_draft])
    if not draft_probs_rows:
        return None
    return torch.cat(draft_probs_rows, dim=0).contiguous()# 在 _sample 中使用
def _sample(self, logits, spec_decode_metadata):
    ...
    draft_probs = self._get_spec_decode_draft_probs(spec_decode_metadata)
    sampler_output = self.rejection_sampler(
        spec_decode_metadata,
        draft_probs, # 之前是 None
        logits,
        sampling_metadata,
    )
    return sampler_output

评论区精华

draft_sample_method 命名调整 设计

benchislett 指出 'gumbel' 是 MRV2 的特定实现,V1 实际是概率采样,建议重命名为 'probabilistic' 或 'sampled'

结论:bedeks 采纳并重命名为 probabilistic,同时更新文档字符串。 · 已解决

处理请求无 draft token 的边界情况 正确性

benchislett 对 _get_spec_decode_draft_probs 中 num_draft == 0 的跳过逻辑提出疑问,担心打乱顺序。

结论:bedeks 解释跳过零长度条目不会影响顺序,且有回归测试验证请求顺序 [a,b,c] 对应 draft 数量 [2,0,1] 时输出正确。 · 已解决

draft_probs_list 初始化值 正确性

benchislett 指出 draft_probs_list = [] 然后检查 is None 可能不合适,应设为 None。

结论:bedeks 确认并修正为 None。 · 已解决

风险与影响

  1. 仅当 rejection_sample_method='standard' 且 draft_sample_method='probabilistic' 时才启用新路径,默认行为不受影响。
  2. 新增的 _last_draft_probs 等缓存若不及时清理,可能导致内存占用缓慢增长;代码在 propose_draft_token_ids 和 sample_tokens 中都设置了显式重置为 None。
  3. 请求顺序重新映射逻辑假定 input_batch.req_ids 与 spec_decode_metadata.num_draft_tokens 顺序一致,若不符可能导致概率错位;测试验证了正确性。
  4. 若请求缺失缓存概率(如新的请求进入但概率尚未填充),代码会回退到 None(即 fallback 行为),不会崩溃,但可能短暂回归旧行为。

影响范围限于 V1 speculative decoding 中采用 probabilistic rejection 的用户。该修复使 acceptance rate 从约 0.22 提升至 0.45,显著提高推理性能。对使用默认 greedy 采样的用户无影响。新增测试覆盖确保了核心路径的可靠性。配置变更要求用户将 'gumbel' 改为 'probabilistic',但之前该设置仅适用于 MRV2,V1 本应使用 'probabilistic',因此这是预期修复。

核心路径变更 数据流顺序依赖 缓存生命周期

关联 Issue

#40149 [Feature]: Speculative Decoding using draft_model does not use draft_probs

完整报告

参与讨论