Prhub

#40651 [Model Runner V2] Fix rejection sampling acceptance rate gap vs MRV1

原始 PR 作者 TheEpicDolphin 合并时间 2026-04-27 10:12 文件变更 4 提交数 1 评论 1 代码增减 +126 / -148

执行摘要

MRV2 启用 one-hot 拒绝采样,修复接受率差距

MRV1 使用 argmax 采样 draft tokens,将 draft 概率分布视为 one-hot,改变了标准拒绝采样的概率测试 u < p(x)/q(x)u < p(x),从而获得更高效的接受率。而 MRV2 默认使用 strict acceptance(只接受完全匹配的 draft tokens),导致接受率较低。MRV2 的 probabilistic 方法虽然使用完整 draft 概率但需要存储 full draft logits,内存开销大。因此需要实现一个中间方案:在 MRV2 中启用与 MRV1 等价的 one-hot 拒绝采样,以修复接受率差距,同时避免额外内存消耗。

建议所有使用 MRV2 推测解码的同学阅读此 PR,了解新的 draft_sample_method 配置及其对接受率和内存的影响。特别值得关注的设计模式是:通过 Triton 编译常量(HAS_DRAFT_LOGITS)在编译期分支内核逻辑,无运行时开销。同时关注后续可能的扩展:允许在不缓存 draft logits 的情况下使用随机采样。

讨论亮点

benchislett 评论:“Need to double-check the math but at a high level it looks good”,表明 reviewer 关注数学正确性,但大致认可方案。

gemini-code-assist[bot] 自动审查指出:当 draft_sample_method'greedy' 时强制使用 argmax 限制了随机采样选项,建议允许即使不缓存 draft logits 也能使用 gumbel 采样(processed_logits_out=None)。这是一个合理的设计权衡:当前实现以内存效率为重,牺牲了随机 drafted 的灵活性。

实现拆解

  1. 配置模块(config/speculative.py):将 RejectionSampleMethod 类型从 'strict' | 'probabilistic' | 'synthetic' 改为 'standard' | 'synthetic',新增 DraftSampleMethod 类型 'greedy' | 'gumbel',并添加 draft_sample_method 配置字段(默认 'greedy')。同时更新了 rejection_sample_method 的默认值为 'standard'

  2. 概率拒绝采样工具(probabilistic_rejection_sampler_utils.py):将核心内核重命名为 _compute_block_stats_kernel,新增编译常量 HAS_DRAFT_LOGITS。当 HAS_DRAFT_LOGITSFalse 时,跳过 draft 统计量的计算,仅计算 target 的 argmax 或统计量。在 _probabilistic_rejection_kernel 中也加入 HAS_DRAFT_LOGITS 条件,避免加载和计算 draft logits。

  3. 拒绝采样器(rejection_sampler.py):删除 _strict_rejection_sample_kernelstrict_rejection_sample 函数。将 __call__ 方法中的 'strict''probabilistic' 分支合并为 'standard' 分支,始终调用 probabilistic_rejection_sample,该函数根据是否传入 draft_logits 自动选择行为(有 draft_logits 时使用完整概率,无时使用 one-hot 近似)。

  4. Eagle 推测器(eagle/speculator.py):新增 _sample_draft 方法,根据 self.draft_logits 是否分配(由 draft_sample_method 决定)来选择 gumbel 采样(分配时)或 argmax 采样(未分配时)。重构 prefillgenerate_draft 方法,将重复的 gumbel_sample 调用替换为 _sample_draft,消除代码重复并集中控制 draft 采样策略。

文件 模块 状态 重要度
vllm/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py 推测解码 modified 8.05
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py 推测解码 modified 7.74
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py 推测解码 modified 7.24
vllm/config/speculative.py 配置 modified 5.97

关键符号

_sample_draft _compute_block_stats_kernel _probabilistic_rejection_kernel probabilistic_rejection_sample RejectionSampler.__call__

关键源码片段

vllm/v1/worker/gpu/spec_decode/rejection_sampler.py core-logic

删除了 strict_rejection_sample,简化拒绝采样器控制流,统一使用 probabilistic_rejection_sample。

class RejectionSampler:
    def __call__(
        self,
        logits: torch.Tensor,
        input_batch: InputBatch,
        draft_logits: torch.Tensor | None = None,
    ) -> SamplerOutput:
        draft_sampled = input_batch.input_ids[input_batch.logits_indices]
        num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
​
        if self.rejection_sample_method == "standard":
            # 'standard' 方法统一使用 probabilistic_rejection_sample
            # 该函数内部根据 draft_logits 是否为 None 自动选择合适的 rejection 策略
            pos = input_batch.positions[input_batch.logits_indices]
            processed_logits = self.sampler.apply_sampling_params(
                logits, input_batch.expanded_idx_mapping,
                input_batch.idx_mapping_np, pos, draft_sampled,
                input_batch.expanded_local_pos,
            )
            sampled, num_sampled = probabilistic_rejection_sample(
                processed_logits, draft_logits, draft_sampled,
                input_batch.cu_num_logits, pos,
                input_batch.idx_mapping, input_batch.expanded_idx_mapping,
                input_batch.expanded_local_pos,
                self.sampler.sampling_states.temperature.gpu,
                self.sampler.sampling_states.seeds.gpu,
                self.num_speculative_steps,
            )
            # 计算 logprobs 时使用 processed 或 raw logits
            logprobs_tensors = self._get_logprobs_tensors(...)
        elif self.rejection_sample_method == "synthetic":
            # synthetic 方法使用预定义的接受率,不依赖实际 logits
            ...
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py core-logic

新增 _sample_draft 方法封装 draft 采样逻辑,根据配置选择 gumbel 或 argmax,重构预填充与 decode 流程。

def _sample_draft(
    self,
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    step: int,
) -> torch.Tensor:
    # 根据 draft_logits 是否分配(由 draft_sample_method 控制)
    # 选择采样策略
    if self.draft_logits is not None:
        # 'gumbel' 模式:使用 Gumbel 噪声随机采样,
        # 并将原始 logits 保存到 draft_logits 供 rejection sampling 使用
        return gumbel_sample(
            logits,
            idx_mapping,
            self.temperature,
            self.seeds,
            pos + 1, # 注意:位置加 1 以使 Gumbel 噪声与目标采样同步
            apply_temperature=True,
            processed_logits_out=self.draft_logits[:, step],
        )
    else:
        # 'greedy' 模式(默认):直接取 argmax,
        # 等价于将 draft 分布视为 one-hot,与 MRV1 行为一致
        return logits.argmax(dim=-1)

评论区精华

强制 argmax 限制随机采样 设计

gemini-code-assist[bot] 指出:当 cache_draft_logits 为 False 时强制使用 argmax,不允许随机 drafting 但允许 one-hot rejection。这可能限制灵活性。

结论:PR 作者保留设计:greedy 模式以内存效率优先,后续可扩展。评论显示为设计权衡。 · acknowledged

数学正确性检查 正确性

benchislett 评论:“Need to double-check the math but at a high level it looks good”。表明 reviewer 关注数学推导,但整体认可。

结论:未进一步讨论,PR 被合并,大概率数学正确。 · acknowledged

风险与影响

  1. 配置兼容性rejection_sample_method 旧值 'strict''probabilistic' 被移除,用户需要更新配置。
  2. 核心路径重构:删除了 strict_rejection_sample Triton 内核,所有 MRV2 推测解码路径都改为 probabilistic_rejection_sample,可能引入回归。
  3. 默认行为变更:默认值从 'strict' 变为 'standard'(等价旧 'probabilistic' + 'greedy' draft),用户可能观察到不同的采样行为。
  4. Triton 内核正确性:新增的 HAS_DRAFT_LOGITS 编译常量可能导致条件分支未覆盖的边界情况。
  5. 作用域局限draft_sample_method 当前仅影响 MRV2 的 Eagle 模型,其他推测方法或 MRV1 不受影响。

用户:使用 MRV2 + speculative decoding 的用户将看到接受率提升(benchmark 显示从 17.55% 提升至 20.63%),输出吞吐量提升约 14‑18%。需要更新 config 中的 rejection_sample_method 键名。

系统:默认不分配 draft logits 张量(draft_logits = None),减少了 GPU 内存占用;Triton 内核绕过了不必要的计算,提升效率。

团队:代码结构更清晰,删除了重复的 strict rejection 逻辑,后续维护和优化更简便。

配置兼容性变更 核心路径重构 Triton 内核逻辑变更 默认行为变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论