Prhub

#26424 [Perf][Spec Decoding] Skip cat/topk/sort/gather in draft_forward for topk=1

原始 PR 作者 Qiaolin-Yu 合并时间 2026-06-02 06:37 文件变更 4 提交数 8 评论 12 代码增减 +164 / -26

执行摘要

跳过 topk=1 时 draft_forward 的 cat/topk/sort/gather

当 speculative_eagle_topk == 1 时,draft_forward 中的 cat(score_list).flatten → torch.topk → torch.sort → torch.gather 操作在数学上变为恒等,完全不需要执行 GPU kernel。该优化消除了 bitonicSortKVInPlace + sbtopk::gatherTopK 内核在 DRAFT_DECODE 阶段的调用,从而降低延迟。

值得精读。展示了如何利用数学等价性消除 GPU 内核调用,是性能优化的典型案例。_rebuild_topk1_chain_buffers 的设计和与自适应推测解码的配合值得关注。测试覆盖充分,可放心合入。

讨论亮点

Review 中主要讨论集中在:

  • 空 parent_list 的 dtype 问题:KPham 询问为何在 organize_draft_results 中为 parent_list 显式指定 dtype=torch.long。作者解释该张量将被下游 kernel 以 long 类型读取。
  • 自适应推测解码的关联:KPham 询问 _rebuild_topk1_chain_buffers 的存在是否因为自适应推测解码会动态改变 num_steps。作者确认是。
  • 测试框架提示:KPham 提到可能有新的 Eagle 单元测试框架适用于此测试,但作者未进一步采纳(现有测试已足够)。

实现拆解

  1. 在 EagleDraftWorker 和 StandaloneEagleDraftWorker 的 __init__ 中调用新增的 _rebuild_topk1_chain_buffers 方法,根据 cuda_graph_max_bsmax_running_requests 预分配 _topk1_parents_prealloc_topk1_score_indices_prealloc 张量。该方法断言 num_draft_tokens == num_steps + 1,确保链拓扑有效。
  2. draft_forward 中判断如果 topk == 1 且当前 batch size 不超过预分配大小,则直接使用预分配的 parent_list 和 top_scores_index,将 token_list 拼接后作为 draft_tokens,完全绕过 organize_draft_results 调用的内核操作。否则,回退到慢速路径。
  3. 同时在 organize_draft_results 中添加了 maybe_detect_oob 越界检查,并将空 parent_list 显式指定 dtype=torch.long,与底层内核期望一致。
  4. 新增单元测试 TestEagleWorkerV2Topk1FastPath,通过构造模拟数据验证 fast path 输出与 slow path 完全等价,覆盖 num_steps = 1..4,并测试非法参数时的断言。
文件 模块 状态 重要度
python/sglang/srt/speculative/eagle_worker_v2.py 推测解码 modified 7.58
test/registered/unit/spec/test_eagle_worker_v2_topk1_fastpath.py 推测测试 added 7.43
python/sglang/srt/speculative/standalone_worker_v2.py 推测解码 modified 5.53
python/sglang/srt/speculative/eagle_utils.py 推测工具 modified 5.54

关键符号

_rebuild_topk1_chain_buffers organize_draft_results

关键源码片段

python/sglang/srt/speculative/eagle_worker_v2.py core-logic

核心变更文件。添加了 _rebuild_topk1_chain_buffers 方法,在 __init__ 中预分配 topk=1 链的父关系和索引常量;修改 draft_forward 在 topk==1 且 batch 大小落在预分配范围内时跳过 cat/topk/sort/gather,直接使用预分配结果。

# 在 __init__ 中预分配 topk=1 链常量
self._topk1_parents_prealloc = None
self._topk1_score_indices_prealloc = None
self._rebuild_topk1_chain_buffers()def _rebuild_topk1_chain_buffers(self) -> None:
    # 当 topk=1 时,草稿树退化为链,父列表和分数索引在运行时不变
    if self.topk != 1:
        return
    # 断言:链拓扑要求 num_draft_tokens == num_steps + 1
    assert self.speculative_num_draft_tokens == self.speculative_num_steps + 1
    num_steps = self.speculative_num_steps
    sa = self.server_args
    max_bs = max(sa.cuda_graph_max_bs or 0, sa.max_running_requests or 0, 1)
    # 单步时没有父条目
    parent_width = num_steps if num_steps > 1 else 0
    self._topk1_parents_prealloc = torch.arange(
        -1, parent_width - 1, dtype=torch.long, device=self.device
    ).repeat(max_bs, 1)
    self._topk1_score_indices_prealloc = torch.arange(
        num_steps, dtype=torch.long, device=self.device
    ).repeat(max_bs, 1)# 在 draft_forward 中的快速路径分支
if self.topk == 1 and token_list[0].shape[0] <= self._topk1_parents_prealloc.shape[0]:
    # 链拓扑:parent 和 index 使用预分配常量,tokens 直接拼接
    parent_list = self._topk1_parents_prealloc[:token_list[0].shape[0]]
    top_scores_index = self._topk1_score_indices_prealloc[:token_list[0].shape[0]]
    draft_tokens = torch.cat(token_list, dim=1)
else:
    # 回退到慢速路径(调用 organize_draft_results)
    parent_list, top_scores_index, draft_tokens = organize_draft_results(
        score_list, token_list, parents_list, self.speculative_num_draft_tokens
    )
test/registered/unit/spec/test_eagle_worker_v2_topk1_fastpath.py test-coverage

新增单元测试,验证快速路径输出与慢速路径 organize_draft_results 一致,涵盖 num_steps=1..4 和断言检查。

class TestEagleWorkerV2Topk1FastPath(CustomTestCase):
    def test_fast_path_matches_slow_path(self):
        bs = 3
        for num_steps in (1, 2, 3, 4):
            with self.subTest(num_steps=num_steps):
                num_draft_tokens = num_steps + 1
                worker = _make_worker(num_steps, num_draft_tokens)
                worker._rebuild_topk1_chain_buffers()
​
                score_list, token_list, parents_list = _make_chain_lists(num_steps, bs)
                ref_parent, ref_index, ref_tokens = organize_draft_results(
                    score_list, token_list, parents_list, num_draft_tokens
                )
​
                fast_parent = worker._topk1_parents_prealloc[:bs]
                fast_index = worker._topk1_score_indices_prealloc[:bs]
                fast_tokens = torch.cat(token_list, dim=1)
​
                self.assertEqual(fast_parent.shape, ref_parent.shape)
                self.assertEqual(fast_parent.tolist(), ref_parent.long().tolist())
                self.assertEqual(fast_index.tolist(), ref_index.long().tolist())
                self.assertEqual(fast_tokens.tolist(), ref_tokens.tolist())
                # 确认是 contiguous 的 long 张量,内核通过 data_ptr 读取
                self.assertEqual(fast_parent.dtype, torch.long)
                self.assertEqual(fast_index.dtype, torch.long)
                self.assertTrue(fast_parent.is_contiguous())
                self.assertTrue(fast_index.is_contiguous())
​
    def test_assert_on_inconsistent_steps_and_draft_tokens(self):
        # num_draft_tokens 必须等于 num_steps + 1
        worker = _make_worker(num_steps=3, num_draft_tokens=3)
        with self.assertRaises(AssertionError):
            worker._rebuild_topk1_chain_buffers()

评论区精华

organize_draft_results 中空 parent_list 的 dtype 指定 正确性

KPham 询问为何在 organize_draft_results 中为 parent_list 显式指定 dtype=long。

结论:作者解释该张量将被下游 kernel 以 long 类型读取,因此必须保持一致。 · 已解决

_rebuild_topk1_chain_buffers 与自适应推测解码的关系 设计

KPham 询问 _rebuild_topk1_chain_buffers 方法的存在是否因为自适应推测解码会动态改变 speculative_num_steps。

结论:作者确认是,该方法设计为可被重新调用来适应变化的拓扑参数。 · 已解决

可用的新 Eagle 单元测试框架提示 测试

KPham 提到可能有新的 Eagle Harness 适用于此测试,暗示可以复用。

结论:作者未明确回应,当前提交的测试已充分。 · unresolved

风险与影响

  • 正确性风险:快速路径的输出必须与慢速路径一致。单元测试覆盖了 num_steps=1..4 和 batch=3 的场景,但可能遗漏其他 batch size 或更复杂的链长度。预分配大小不足时会自动回退到慢路径,不会产生错误结果。
  • 性能风险:仅影响 topk=1 路径,额外开销为构造时的一次性分配和运行时的一次形状检查,可以忽略。
  • 兼容性:非 topk=1 路径完全不变,无破坏性。
  • 安全性:添加的 OOB 检测有助于及早发现索引越界,属于安全增强。
  • 用户:配置 speculative_eagle_topk=1 的用户将获得明显的每请求延迟降低和吞吐提升;其他用户不受影响。
  • 系统:draft_forward 的 GPU 时间减少,可能释放 GPU 资源用于其他请求。
  • 团队:维护成本低,但后续如果修改 speculative_num_steps 等参数,必须重新调用 _rebuild_topk1_chain_buffers(当前在构造时调用一次;自适应推测解码可能需要额外触发)。
核心路径变更 自适应集成需手动重建 测试覆盖主要路径

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论