执行摘要
提取模拟接受长度采样函数以复用
方便其他推测解码算法复用模拟接受长度生成逻辑。
值得合入,提高代码复用性。建议关注 _sample_simulated_acc_len 的用户,并考虑添加单元测试覆盖。
无 review 评论。自动机器人确认了代码重构并指出错误消息修复。
方便其他推测解码算法复用模拟接受长度生成逻辑。
值得合入,提高代码复用性。建议关注 _sample_simulated_acc_len 的用户,并考虑添加单元测试覆盖。
无 review 评论。自动机器人确认了代码重构并指出错误消息修复。
python/sglang/srt/speculative/spec_utils.py 中新增 _sample_simulated_acc_len(simulate_acc_len, simulate_acc_method, max_len) 函数,返回采样后的整数长度。generate_simulated_accept_index 中原有的采样代码替换为对 _sample_simulated_acc_len 的调用。spec_steps + 1 替换为参数 max_len,使其更通用。SIMULATE_ACC_METHOD 全局常量而非 simulate_acc_method 参数的 bug。| 文件 | 模块 | 状态 | 重要度 |
|---|---|---|---|
python/sglang/srt/speculative/spec_utils.py |
推测解码 | modified | 7.01 |
python/sglang/srt/speculative/spec_utils.py
core-logic
核心重构文件,提取采样逻辑为独立函数,修复错误消息 bug。
# python/sglang/srt/speculative/spec_utils.py
def _sample_simulated_acc_len(
simulate_acc_len: float,
simulate_acc_method: str,
max_len: int,
) -> int:
"""
Sample a simulated acceptance length in [1, max_len].
提取自 generate_simulated_accept_index,供其他推测解码算法复用。
max_len 替代原来的 spec_steps + 1,更加通用。
"""
if simulate_acc_method == "multinomial":
simulated_values = torch.normal(
mean=simulate_acc_len,
std=1.0,
size=(1,),
device="cpu",
)
# clamp simulated values to be between 1 and max_len
simulated_values = torch.clamp(simulated_values, min=1.0, max=max_len)
simulate_acc_len = int(simulated_values.round().item())
elif simulate_acc_method == "match-expected":
simulate_acc_len = max(1.0, min(max_len, simulate_acc_len))
lower = int(simulate_acc_len // 1)
upper = lower + 1 if lower < max_len else lower
if lower == upper:
simulate_acc_len = lower
else:
weight_upper = simulate_acc_len - lower
weight_lower = 1.0 - weight_upper
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
sampled_index = torch.multinomial(probs, num_samples=1)
simulate_acc_len = lower if sampled_index == 0 else upper
else:
raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")
return int(simulate_acc_len)
def generate_simulated_accept_index(
accept_index,
predict,
num_correct_drafts,
bs,
spec_steps,
simulate_acc_len: float = SIMULATE_ACC_LEN,
simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
assert simulate_acc_len > 0.0
# 使用提取的函数
simulate_acc_len = _sample_simulated_acc_len(
simulate_acc_len, simulate_acc_method, spec_steps + 1
)
# ... 后续逻辑不变
当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。
风险较低。提取的函数逻辑与原来一致,仅改变组织方式。参数 max_len 替代 spec_steps + 1 需要调用方确保传入正确值。
对现有功能无影响,重构后 generate_simulated_accept_index 行为不变。未来其他推测解码算法可直接调用 _sample_simulated_acc_len。
当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。
参与讨论