Prhub

#26768 Refactor simulated acceptance length generation

原始 PR 作者 JonnyKong 合并时间 2026-06-04 09:31 文件变更 1 提交数 1 评论 3 代码增减 +27 / -16

执行摘要

提取模拟接受长度采样函数以复用

方便其他推测解码算法复用模拟接受长度生成逻辑。

值得合入,提高代码复用性。建议关注 _sample_simulated_acc_len 的用户,并考虑添加单元测试覆盖。

讨论亮点

无 review 评论。自动机器人确认了代码重构并指出错误消息修复。

实现拆解

  1. python/sglang/srt/speculative/spec_utils.py 中新增 _sample_simulated_acc_len(simulate_acc_len, simulate_acc_method, max_len) 函数,返回采样后的整数长度。
  2. generate_simulated_accept_index 中原有的采样代码替换为对 _sample_simulated_acc_len 的调用。
  3. 将硬编码的 spec_steps + 1 替换为参数 max_len,使其更通用。
  4. 修复错误消息中使用 SIMULATE_ACC_METHOD 全局常量而非 simulate_acc_method 参数的 bug。
文件 模块 状态 重要度
python/sglang/srt/speculative/spec_utils.py 推测解码 modified 7.01

关键符号

_sample_simulated_acc_len generate_simulated_accept_index

关键源码片段

python/sglang/srt/speculative/spec_utils.py core-logic

核心重构文件,提取采样逻辑为独立函数,修复错误消息 bug。

# python/sglang/srt/speculative/spec_utils.pydef _sample_simulated_acc_len(
    simulate_acc_len: float,
    simulate_acc_method: str,
    max_len: int,
) -> int:
    """
    Sample a simulated acceptance length in [1, max_len].    提取自 generate_simulated_accept_index,供其他推测解码算法复用。
    max_len 替代原来的 spec_steps + 1,更加通用。
    """
    if simulate_acc_method == "multinomial":
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and max_len
        simulated_values = torch.clamp(simulated_values, min=1.0, max=max_len)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        simulate_acc_len = max(1.0, min(max_len, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < max_len else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")
    return int(simulate_acc_len)
​
​
def generate_simulated_accept_index(
    accept_index,
    predict,
    num_correct_drafts,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    assert simulate_acc_len > 0.0
    # 使用提取的函数
    simulate_acc_len = _sample_simulated_acc_len(
        simulate_acc_len, simulate_acc_method, spec_steps + 1
    )
    # ... 后续逻辑不变

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险较低。提取的函数逻辑与原来一致,仅改变组织方式。参数 max_len 替代 spec_steps + 1 需要调用方确保传入正确值。

对现有功能无影响,重构后 generate_simulated_accept_index 行为不变。未来其他推测解码算法可直接调用 _sample_simulated_acc_len

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论