Prhub

#43150 [BUG] Fix FP64 Gumbel precision coverage

原始 PR 作者 tianyu-z 合并时间 2026-06-05 19:04 文件变更 11 提交数 23 评论 10 代码增减 +391 / -21

执行摘要

修复 FP64 Gumbel 精度未覆盖 V1 采样路径

在 GPU 上,fp32 指数随机数生成无法表示极小的下尾事件(低于 2^-24)。指数竞争形式的 Gumbel-max 采样(q.exponential_(); probs.div(q).argmax())等价于显式 Gumbel 采样,因此也受此问题影响。对于宽分布/多并行分类竞争,缺失的下尾可能会系统性移除稀有 token 的获胜机会。本 PR 旨在将 use_fp64_gumbel 的覆盖范围扩展到 V1 采样、 rejection sampler 和 draft proposer 的所有指数竞争路径。

值得精读,尤其是 topk_topp_sampler.py 中的辅助函数设计和 sample_with_exponential_noise 的 dtype 处理逻辑。本 PR 展示了如何系统地修复一个隐藏的精度 bug,并在多个采样路径中保证一致性。对于关注采样精度和公平性的开发者具有参考价值。

讨论亮点

Reviewer njhill 在总体评论中指出团队正在迁移至 MRV2,不计划对 MRV1 做此类增强,但认为改动扎实并最终批准。此外,njhill 对 vllm/v1/worker/gpu/sample/gumbel.py 中的注释修改提出疑问,作者随后还原了该文件的修改。njhill 还对 empty_exponential_noise_likesample_with_exponential_noise 的函数签名格式提出了简化建议,作者均已采纳。

实现拆解

  1. 参数传递链条:在 vllm/config/model.py 中更新 use_fp64_gumbel 的文档字符串,明确其影响范围。在 vllm/v1/worker/gpu_model_runner.py 中从 ModelConfig 读取标志并传递给 Sampler。在 vllm/v1/sample/sampler.py__init__ 中添加 use_fp64_gumbel 参数,并传递给 TopKTopPSampler

  2. 核心采样逻辑修改:在 vllm/v1/sample/ops/topk_topp_sampler.py 中:

    • 新增 empty_exponential_noise_like 函数,根据 use_fp64_gumbel 返回 float64 或原生 dtype 的空张量。
    • 新增 sample_with_exponential_noise 函数,当噪声 dtype 与概率不一致时先取倒数再相乘,避免精度损失。
    • 修改 forward_nativeforward_cpuforward_cudaforward_hip 方法,在 use_fp64_gumbel=True 时强制走原生路径并使用 fp64 噪声。
    • 对于 FlashInfer 和 ROCm aiter 后端,当 use_fp64_gumbel=True 时也回退到原生实现。
  3. 精细化解码路径填充:在 vllm/v1/sample/rejection_sampler.pysample_recovered_tokens 函数中添加 use_fp64_gumbel 参数,控制噪声 dtype。在 vllm/v1/spec_decode/llm_base_proposer.pycompute_probs_and_sample_next_token 中添加 use_fp64_gumbel 参数,传递到内部的指数噪声采样。

  4. 测试与证明:新增 4 个单元测试,分别验证标志线程、random_samplesample_recovered_tokenscompute_probs_and_sample_next_tokenuse_fp64_gumbel=True 时产生与显式 fp64 参考实现一致的结果。新增 tools/gumbel_precision/prove_exponential_race_precision.py,作为 CUDA 统计证明脚本。

文件 模块 状态 重要度
vllm/v1/sample/ops/topk_topp_sampler.py 采样器 modified 6.15
vllm/v1/sample/sampler.py 采样器 modified 6.16
tests/v1/sample/test_topk_topp_sampler.py 测试 modified 6.56

关键符号

Sampler.__init__ TopKTopPSampler.__init__ empty_exponential_noise_like sample_with_exponential_noise random_sample sample_recovered_tokens compute_probs_and_sample_next_token

关键源码片段

vllm/v1/sample/ops/topk_topp_sampler.py core-logic

核心采样逻辑修改,新增 fp64 指数噪声辅助函数,并修改多个 forward 方法以支持 use_fp64_gumbel 分支。

def empty_exponential_noise_like(
    probs: torch.Tensor,
    use_fp64_gumbel: bool,
) -> torch.Tensor:
    # 如果启用 fp64 gumbel,则创建 float64 空张量,否则使用 probs 的 dtype(通常是 float32)
    dtype = torch.float64 if use_fp64_gumbel else probs.dtype
    return torch.empty(probs.shape, dtype=dtype, device=probs.device)def sample_with_exponential_noise(
    probs: torch.Tensor,
    q: torch.Tensor,
) -> torch.Tensor:
    # 当噪声 dtype 与 probs 不一致时,先 reciprocal 再 multiply 避免 fp64 / fp32 直接除法的精度损失
    if q.dtype == probs.dtype:
        scores = probs.div_(q)
    else:
        scores = q.reciprocal_()
        scores.mul_(probs)
    return scores.argmax(dim=-1).view(-1)
vllm/v1/sample/sampler.py core-logic

Sampler 构造函数接受 use_fp64_gumbel 并传递给 TopKTopPSampler,是参数传递的关键环节。

class Sampler(nn.Module):
    # ... 其他代码 ...
    def __init__(
        self,
        logprobs_mode: LogprobsMode = "raw_logprobs",
        use_fp64_gumbel: bool = False, # 新增参数,用于控制是否使用 fp64 指数噪声
    ):
        super().__init__()
        # 将标志传递给 TopKTopPSampler,确保下游采样路径生效
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode, use_fp64_gumbel)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode
        self.use_fp64_gumbel = use_fp64_gumbel
tests/v1/sample/test_topk_topp_sampler.py test-coverage

新增测试验证标志线程和 fp64 随机采样,是回归测试的重要组成部分。

def test_random_sample_uses_fp64_exponential_race_when_requested():
    torch.set_default_device(DEVICE_TYPE)
    probs = torch.tensor(
        [[0.70, 0.20, 0.10],
         [0.05, 0.15, 0.80],
         [0.25, 0.25, 0.50]],
        dtype=torch.float32,
        device=DEVICE_TYPE,
    )
    _seed_default_generator(12345)
    # 手动构造参考结果:fp64 指数噪声
    q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    q.exponential_()
    expected = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1)
​
    _seed_default_generator(12345)
    # 调用实际函数,并启用 fp64
    actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True)
​
    assert torch.equal(actual, expected)

评论区精华

对 MRV1 增强的总体评价 设计

Reviewer njhill 指出团队正在迁移至 MRV2,不计划对 MRV1 做此类增强,但认为改动扎实并最终批准。

结论:团队认可改动,合并到 MRV1。 · 已解决

注释修改建议 style

njhill 对 vllm/v1/worker/gpu/sample/gumbel.py 中的注释修改提出疑问,作者随后还原了该文件的修改。

结论:注释还原,保持原样。 · 已解决

函数签名格式简化 style

njhill 对 empty_exponential_noise_like 和 sample_with_exponential_noise 的函数签名格式提出了简化建议。

结论:作者采纳建议,简化了签名格式。 · 已解决

风险与影响

  • 性能风险:启用 use_fp64_gumbel 后,float64 指数噪声生成和除法运算比 float32 慢约 1/32-1/64,在吞吐敏感场景下可能影响性能。但默认配置不变,仅显式请求的用户受影响。
  • 正确性风险:新增的 sample_with_exponential_noise 函数在 dtype 不一致时使用 reciprocal+multiply,经验证与直接除法等价。单元测试覆盖了关键路径。
  • 兼容性风险:修改涉及多个采样后端(CUDA/ROCm/CPU/FlashInfer),在 use_fp64_gumbel=True 时统一回退到原生路径,与原逻辑一致,无兼容性问题。
  • 回归风险:通过 4 个单元测试和 CUDA 统计证明脚本验证,主要路径正确性有保障。
  • 用户影响:启用 --use-fp64-gumbel 的用户现在在 V1 采样和精细化解码中也能获得 fp64 精度,改善低概率 token 的生成公平性。默认用户无感知,升级无缝。
  • 系统影响:新增了非生产性的统计证明脚本,不占用运行时资源。
  • 团队影响:统一的标志覆盖策略使所有 Gumbel 等价位路径保持一致,降低了后续维护成本。
性能降级 ( 启用时 ) fp32 下尾截断修复 核心采样路径变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论