Prhub

#22128 Allow piecewise CUDA graph with speculative decoding

原始 PR 作者 narutolhy 合并时间 2026-04-17 13:39 文件变更 4 提交数 12 评论 45 代码增减 +272 / -18

执行摘要

允许 PCG 与所有投机解码算法共存

PCG 和投机解码操作在独立的前向路径上:PCG 处理 prefill/extend 路径(ForwardMode.EXTEND),投机解码使用 decode CUDA graphs 或 eager 执行。原有限制是 #16331 作为保守安全措施添加的,未经兼容性测试。本 PR 旨在解除限制并确保安全共存,让用户同时获得 PCG 的 prefill 加速和投机解码的 decode 加速。

值得精读。该 PR 展示了如何通过运行时安全检查而非全局禁用实现功能兼容,设计思路清晰。重点关注:

  • can_run 中两个守卫条件的语义(ForwardMode.TARGET_VERIFYcapture_hidden_mode 匹配)。
  • Draft 模型跳过 PCG 初始化的逻辑及其对多模型架构的影响。
  • 测试文件中如何编排多 GPU 环境和内存限制。
    该 PR 的演变过程(从简单移除到逐步修复兼容问题)也提供了良好的工程实践参考。
讨论亮点

Review 中核心讨论包括:

  • 设计权衡:gemini-code-assist 建议使用 set 仅允许 NEXTN 算法(if self.speculative_algorithm not in {"NEXTN"}: disable),但 Oasis-Git 建议直接移除整个 speculative 检查,作者采纳后者,因为所有投机算法均兼容。
  • Draft 模型初始化崩溃:作者发现 draft 模型的 forward 在 PCG 预热时返回 None,导致崩溃(提交 6c61f7e),通过 model_runner.py 中增加 is_draft_worker 检查修复。
  • Capture hidden mode 不匹配:EAGLE3 测试中 PCG 以 CaptureHiddenMode.NULL 捕获,但投机解码的 extend batch 设置了 CaptureHiddenMode.FULL,导致 PCG 回放错误(提交 df32a26),通过在 can_run 中比较 capture_hidden_mode 修复。
  • PCG 是否实际生效:Chen-0210 指出默认 flags 下 PCG 可能未实际执行,因为缺少 --enable-return-hidden-states 来启用 full capture_mode;作者确认需要跟进优化。

实现拆解

步骤 1:移除 server_args.py 中的 blanket 禁用
_handle_piecewise_cuda_graph 方法中删除 if self.speculative_algorithm is not None: self.disable_piecewise_cuda_graph = True 及相关注释,并将后续条件编号依次前移。

步骤 2:在 piecewise_cuda_graph_runner.pycan_run 增加安全守卫
新增两个检查:

  • 如果 forward_batch.forward_mode.is_target_verify(),返回 False——PCG 图以 ForwardMode.EXTEND 捕获,不能用于 TARGET_VERIFY 模式。
  • 如果 forward_batch.capture_hidden_mode != self.capture_hidden_mode,返回 False——避免 PCG 回放返回错误或缺失的 hidden_states。

步骤 3:在 model_runner.py 跳过 draft 模型的 PCG 初始化
init_piecewise_cuda_graphs 开头增加 if self.is_draft_worker: return,因为 draft 模型使用 decode CUDA graphs 而不是 PCG。

步骤 4:新增端到端测试文件
创建 test/registered/piecewise_cuda_graph/test_pcg_with_speculative_decoding.py,包含四个测试类:TestPCGWithMTPTestPCGWithEAGLE3TestPCGWithSTANDALONETestPCGWithNGRAM。每个类启动对应配置的服务器,运行 GSM8K 评测验证精度,并检查平均投机接受长度。

步骤 5:调整测试资源参数
将 EAGLE3 测试的 mem_fraction_static 从 0.65 降至 0.55,避免 PCG 图捕获与 decode CUDA graphs 同时占用过多 GPU 内存导致 OOM。

文件 模块 状态 重要度
test/registered/piecewise_cuda_graph/test_pcg_with_speculative_decoding.py PCG 测试 added 7.76
python/sglang/srt/server_args.py 服务器配置 modified 7.17
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py PCG 执行器 modified 6.53
python/sglang/srt/model_executor/model_runner.py 模型运行器 modified 5.68

关键符号

_handle_piecewise_cuda_graph can_run init_piecewise_cuda_graphs

关键源码片段

test/registered/piecewise_cuda_graph/test_pcg_with_speculative_decoding.py test-coverage

新增测试文件,验证 PCG 与 NEXTN/EAGLE3/STANDALONE/NGRAM 四种投机解码算法的兼容性,是功能正确性的关键证明。

"""Test piecewise CUDA graph coexisting with speculative decoding.PCG handles prefill/extend path while speculative decoding (MTP/EAGLE3/STANDALONE/NGRAM)
uses decode CUDA graphs. This test verifies they don't interfere with each other.
"""import unittest
from types import SimpleNamespaceimport requestsfrom sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)register_cuda_ci(est_time=600, suite="stage-b-test-2-gpu-large")
​
​
class TestPCGWithMTP(unittest.TestCase):
    """Test PCG + MTP (NEXTN) on Qwen3.5-35B-A3B with FP8."""
​
    @classmethod
    def setUpClass(cls):
        cls.model = "Qwen/Qwen3.5-35B-A3B"
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--tp",
            "2",
            "--trust-remote-code",
            "--quantization",
            "fp8",
            "--mamba-scheduler-strategy",
            "extra_buffer",
            "--enable-piecewise-cuda-graph", # 启用 PCG
            "--speculative-algorithm",
            "NEXTN", # 使用 NEXTN 投机解码
            "--reasoning-parser",
            "qwen3",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 3,
            other_args=other_args,
        )
​
    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
​
    def test_gsm8k(self):
        # 使用 GSM8K 评测验证精度
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="gsm8k",
            max_tokens=8192,
            num_examples=200,
            num_threads=200,
            thinking_mode="qwen3",
        )
        metrics = run_eval(args)
        print(metrics)
        # 验证精度 > 0.75(正常水平)
        self.assertGreater(metrics["score"], 0.75)
​
        # 检查平均投机接受长度 > 1.5,确保投机解码正常工作
        server_info = requests.get(self.base_url + "/server_info").json()
        avg_spec_accept_length = server_info["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 1.5)

其他三个测试类(EAGLE3/STANDALONE/NGRAM)结构类似,使用不同模型和参数组合。

python/sglang/srt/server_args.py core-logic

核心变更入口:移除 `_handle_piecewise_cuda_graph` 中对 speculative decoding 的 blanket 禁用,使 PCG 与所有投机算法兼容。

def _handle_piecewise_cuda_graph(self):
    # Skip auto-disable when enforce flag is set (for testing)
    if self.enforce_piecewise_cuda_graph:
        self.disable_piecewise_cuda_graph = False
        return
​
    # Disable piecewise cuda graph with following conditions:
    # 1. Disable Model Arch
    if self.get_model_config().is_piecewise_cuda_graph_disabled_model:
        self.disable_piecewise_cuda_graph = True
    # 2. DP attention ( 原 #2 号 speculative 条件已移除 )
    if self.enable_dp_attention:
        self.disable_piecewise_cuda_graph = True
    # 3. Torch compile
    if self.enable_torch_compile:
        self.disable_piecewise_cuda_graph = True
    # 4. Pipeline parallelism
    if self.pp_size > 1:
        self.disable_piecewise_cuda_graph = True
    # 5. Non-CUDA hardware (AMD, NPU, CPU, MPS, XPU, etc.)
    if is_hip() or is_npu() or is_cpu() or is_mps() or is_xpu():
        self.disable_piecewise_cuda_graph = True
    # 6. MoE A2A backend
    if self.moe_a2a_backend != "none":
        self.disable_piecewise_cuda_graph = True
    # 7. LoRA
    if self.lora_paths or self.enable_lora:
        self.disable_piecewise_cuda_graph = True
    # 8. Multimodal / VLM models
    if self.get_model_config().is_multimodal:
        self.disable_piecewise_cuda_graph = True
    # 9. GGUF quantized models
    if (
        self.load_format == "gguf"
        or self.quantization == "gguf"
        or check_gguf_file(self.model_path)
    ):
        self.disable_piecewise_cuda_graph = True
    # 10. DLLM models
    if self.dllm_algorithm is not None:
        self.disable_piecewise_cuda_graph = True
    # 11. CPU offload
    if self.cpu_offload_gb > 0 or self.enable_hierarchical_cache:
        self.disable_piecewise_cuda_graph = True
    # 12. Deterministic inference
    if self.enable_deterministic_inference:
        self.disable_piecewise_cuda_graph = True
    # 13. PD disaggregation
    if self.disaggregation_mode != "null":
        self.disable_piecewise_cuda_graph = True
    # 14. Symmetric memory
    if self.enable_symm_mem:
        self.disable_piecewise_cuda_graph = True
    # 15. Expert distribution recorder
    if self.enable_eplb or self.expert_distribution_recorder_mode is not None:
        self.disable_piecewise_cuda_graph = True
    # 16. Context parallel
    if self.attn_cp_size > 1:
        self.disable_piecewise_cuda_graph = True
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py data-contract

关键安全逻辑:在 `can_run` 中增加 `TARGET_VERIFY` 和 `capture_hidden_mode` 检查,确保 PCG 只在匹配的前向模式下回放。

def can_run(self, forward_batch: ForwardBatch):
    # Disable piecewise cuda graph for input embeddings
    # TODO(yuwei): fix it
    if forward_batch.input_embeds is not None:
        return False
​
    # PCG graphs are captured with ForwardMode.EXTEND and spec_info=None.
    # TARGET_VERIFY has different spec_info and capture_hidden_mode,
    # so it must not use PCG-captured graphs.
    if forward_batch.forward_mode.is_target_verify():
        return False
​
    # PCG graphs are captured with the runner's capture_hidden_mode.
    # If the batch needs a different mode (e.g. FULL for speculative
    # decoding), PCG replay would return wrong/missing hidden_states.
    if forward_batch.capture_hidden_mode != self.capture_hidden_mode:
        return False
​
    # Disable for token embedding overrides (dynamic per-request)
    if forward_batch.replace_embeds is not None:
        return False
​
    num_tokens = len(forward_batch.input_ids)
    if forward_batch.return_logprob:
        for start_len, seq_len in zip(
            forward_batch.extend_logprob_start_lens_cpu,
            forward_batch.extend_seq_lens_cpu,
        ):
            if start_len is not None and start_len < seq_len:
                return False
    if num_tokens <= self.max_num_tokens:
        return True
    return False

评论区精华

是否使用 set 限制算法 vs 完全移除 设计

gemini-code-assist 建议使用 set 仅允许 NEXTN 算法;Oasis-Git 建议直接移除整个 speculative 检查,因为所有算法均兼容。

结论:作者同意 Oasis-Git,直接移除条件,未保留算法白名单。 · 已解决

Draft 模型 PCG 初始化崩溃 正确性

作者发现 draft 模型的 forward 在 PCG 预热时返回 None,导致 crash。

结论:在 `model_runner.py` 中增加 `self.is_draft_worker` 检查,跳过 draft 模型的 PCG 初始化(提交 6c61f7e)。 · 已解决

Capture hidden mode 不匹配导致 EAGLE3 错误 正确性

EAGLE3 使用 `CaptureHiddenMode.FULL` 获取 hidden states,但 PCG 以 `NULL` 模式捕获,回放时返回错误结果。

结论:在 `can_run` 中增加 `capture_hidden_mode` 比较检查,确保模式一致(提交 df32a26)。 · 已解决

PCG 在实际场景中是否真正生效 question

Chen-0210 指出默认 flags 下 `capture_hidden_mode` 可能为 NULL,但投机解码需要 FULL 模式,导致 PCG 回退。作者之前的基准测试可能未实际启用 PCG。

结论:作者承认需要跟进优化,计划添加 `--enable-return-hidden-states` 支持以确保 PCG 在 MTP 下生效。 · 待处理

风险与影响

主要风险:

  1. PCG 与投机解码的交互:虽然修复了已知问题,但仍可能有未发现的 corner case,尤其 draft 模型的 prefill 路径尚不支持 PCG(作者正在研究)。
  2. GPU 内存增加:PCG 图捕获额外消耗显存,在高负载或多测试并行时可能 OOM(测试中已降低 mem_fraction_static 缓解)。
  3. 默认未生效:如 Chen-0210 所述,默认配置下 PCG 可能因 capture_hidden_mode 不匹配而回退,用户需显式启用 --enable-return-hidden-states 才能获得完整加速。
  4. 测试覆盖局限:测试基于 GSM8K 评测和特定模型配置,未覆盖所有可能的输入形状和并发场景。

正面影响

  • 用户现在可同时启用 PCG 和任意投机解码算法,获得 prefill 加速(TTFT 降低 42%)和 decode 加速的双重收益。
  • 代码架构更加清晰,将安全性从全局禁用转嫁为运行时细粒度检查,便于未来扩展。

负面影响

  • 增加 GPU 内存压力,需要用户调整 mem_fraction_static 等参数。
  • 默认配置下 PCG 可能未生效,需要用户理解相关 flags(--enable-return-hidden-states)并正确设置。
  • 维护成本增加:新增的守卫条件需与新 forward mode 同步维护。
核心路径变更 新增功能分支 GPU 内存增加 测试覆盖主流算法

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论