Prhub

#41775 [Model Runner V2] FP32 gumbel sampling.

原始 PR 作者 PatchouliTIS 合并时间 2026-05-16 00:20 文件变更 8 提交数 17 评论 10 代码增减 +50 / -7

执行摘要

Gumbel 采样默认使用 FP32 以提升性能

根据 PR body,H20 上的 profile 结果促使优化 Gumbel 采样精度行为。FP64 的使用会降低采样吞吐,而经验上 FP32 对于 Gumbel-max 足够。目标是使 FP64 可选,FP32 成为默认快速选项,并避免 Uniform 为零导致的数值问题。

值得精读学习如何在 Triton 内核中安全切换 FP32/FP64 并处理边界值;以及从环境变量演化到引擎标志的设计决策过程,体现了代码的健壮性和可维护性。

讨论亮点
  • WoosukKwon 建议将临时环境变量改为正式的引擎标志 --use-fp64-gumbel,作者采纳。
  • gemini-code-assist[bot] 指出 _gumbel_sample_kernel 未添加 USE_FP64 参数会导致运行时 TypeError,该问题在后续修复中解决。
  • TheEpicDolphin 要求确认 draft acceptance rates 在使用 FP32 后不会退化,作者确认无问题。同时提出了一个代码风格 nit(将 gumbel_noise 计算提取出分支),作者已修改。
  • TheEpicDolphin 还注意到一个多余文件(pre-commit 输出)被误加入,作者移除。

实现拆解

  1. 添加配置项:在 vllm/config/model.pyModelConfig 中新增 use_fp64_gumbel 布尔字段,默认关闭,不影响计算图。
  2. 暴露 CLI 参数:在 vllm/engine/arg_utils.pyEngineArgs 中添加对应字段,并注册 --use-fp64-gumbel CLI 参数,将其传入 ModelConfig
  3. 修改 Gumbel 采样内核:在 vllm/v1/worker/gpu/sample/gumbel.py 中新增 _FP32_TINY clamp 常量,为 gumbel_block_argmax_gumbel_sample_kernel 添加 USE_FP64 编译时常量参数,根据其值选择 FP64/FP32 路径,并调整 local_max 张量的 dtype。
  4. 修改拒绝采样内核:在 vllm/v1/worker/gpu/spec_decode/rejection_sampler_utils.py 中为 _resample_kernel 添加 USE_FP64 参数,并传递给 gumbel_block_argmax,同时调整 resampled_local_max 的 dtype。
  5. 传播配置:在 SamplerModelRunner 以及推测解码模块的 RejectionSamplerEagleSpeculator 中接收并传递 use_fp64_gumbel,确保内核能获得该配置。
文件 模块 状态 重要度
vllm/v1/worker/gpu/sample/gumbel.py 采样内核 modified 6.77
vllm/v1/worker/gpu/spec_decode/rejection_sampler_utils.py 推测解码 modified 5.84
vllm/config/model.py 模型配置 modified 5.74
vllm/engine/arg_utils.py 引擎参数 modified 5.2
vllm/v1/worker/gpu/sample/sampler.py 采样器 modified 4.99
vllm/v1/worker/gpu/model_runner.py 模型运行器 modified 4.96
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py 推测解码器 modified 4.89
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py 拒绝采样器 modified 4.35

关键符号

gumbel_block_argmax _gumbel_sample_kernel gumbel_sample _resample_kernel rejection_sample Sampler.__init__ ModelRunner.__init__ create_model_config

关键源码片段

vllm/v1/worker/gpu/sample/gumbel.py dependency-wiring

核心采样内核,实现了 FP32/FP64 条件分支和 clamp 逻辑

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import HAS_TRITON, tl, triton# 最小的正规 FP32 正值,用于 clamp 均匀分布采样值,
# 避免 `log(u)` 产生 -inf,从而保证 `-log(-log(u))` 始终有限。
# Triton 要求 `@triton.jit` 中访问的全局变量需用 `tl.constexpr(...)` 包裹,
# 但仅在 Triton 可用时才可调用(CPU worker 路径下 `tl` 为占位符,调用会崩溃)。
_FP32_TINY = (
    tl.constexpr(float.fromhex("0x1p-126")) if HAS_TRITON else float.fromhex("0x1p-126")
)@triton.jit
def gumbel_block_argmax(
    logits,
    block,
    mask,
    token_idx,
    expanded_idx_mapping_ptr,
    temp_ptr,
    seeds_ptr,
    pos_ptr,
    processed_logits_ptr,
    processed_logits_stride,
    processed_logits_col_ptr,
    vocab_size,
    APPLY_TEMPERATURE: tl.constexpr,
    USE_FP64: tl.constexpr, # 新增参数:是否使用 FP64(默认 FP32)
):
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    if temp != 0.0 and APPLY_TEMPERATURE:
        logits = logits / temp
​
    # ... 处理 processed_logits 存储(省略)...
​
    # FP32 是默认规约精度;FP64 在 H100/Ada/Blackwell 上吞吐低 32-64 倍,
    # 且经验上对于 Gumbel-max 不可区分。
    if USE_FP64:
        logits = logits.to(tl.float64)
    if temp != 0.0:
        seed = tl.load(seeds_ptr + req_state_idx)
        pos = tl.load(pos_ptr + token_idx)
        gumbel_seed = tl.randint(seed, pos)
​
        if USE_FP64:
            # 使用 64-bit 随机数构建 FP64 uniform
            u = tl_rand64(gumbel_seed, block, includes_zero=False)
        else:
            # 使用 FP32 uniform,并 clamp 防止零
            u = tl.rand(gumbel_seed, block)
            u = tl.maximum(u, _FP32_TINY)
        gumbel_noise = -tl.log(-tl.log(u))
​
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
​
    value, idx = tl.max(logits, axis=0, return_indices=True)
    return value, idx

评论区精华

将环境变量改为引擎标志 设计

WoosukKwon 认为该功能非临时性,建议使用 `--use-fp64-gumbel` 引擎标志而非环境变量。

结论:作者同意并修改为引擎标志。 · 已解决

缺少 USE_FP64 参数导致 runtime error 正确性

gemini-code-assist[bot] 指出 `_gumbel_sample_kernel` 未被更新以接受 `USE_FP64` 参数,会导致 TypeError。

结论:作者在后续修复中添加了该参数。 · 已解决

确认 FP32 对 draft acceptance 无退化 测试

TheEpicDolphin 要求作者 double check draft acceptance rates 在使用 FP32 后不变。

结论:作者确认无问题,审查者表示信任。 · 已解决

gumbel_noise 计算可以提出分支 style

TheEpicDolphin 提出 `gumbel_noise = -tl.log(-tl.log(u))` 可以提出 if/else 分支以减少重复。

结论:作者修改,将计算移出分支。 · 已解决

误加多余文件 other

TheEpicDolphin 发现一个 pre-commit 输出文件被误加入 PR。

结论:作者移除了该文件。 · 已解决

风险与影响

  • FP32 数值风险:较低的随机精度可能在极少数情况下导致 Gumbel 采样出现 tie,但通过 _FP32_TINY clamp 确保数值稳定,且经验上差异不可感知。用户可通过 --use-fp64-gumbel 回退。
  • 默认行为变更:默认使用 FP32 可能影响依赖原有 FP64 结果的用户,提供了显式启用 FP64 的选项来兼容。
  • 测试缺口:本 PR 未包含单元测试,仅依赖手动验证和 profile。
  • 参数传递遗漏:涉及多个内核文件,可能引入参数传递遗漏,但通过 Review 已捕获。
  • 用户影响:默认 FP32 将提升采样阶段速度,降低延迟,对推理吞吐有益。需要使用精确 FP64 的用户可通过 --use-fp64-gumbel 启用。
  • 系统影响:影响 V2 model runner 的采样和推测解码路径,其他模块不受影响。
  • 团队影响:该 PR 提供了可配置精度选项,便于未来平衡性能与精度。
核心路径变更 FP32 精度风险 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论