Prhub

#24241 [bugfix] Support MIXED forward mode in TBO splitter for DP attention

原始 PR 作者 ch-wan 合并时间 2026-05-02 07:01 文件变更 2 提交数 1 评论 3 代码增减 +34 / -2

执行摘要

修复 DP attention 与 mixed chunk 组合时的崩溃

--enable-dp-attention--enable-mixed-chunk 组合会确定性崩溃,因为 compute_split_seq_indexcompute_split_token_index 未处理 ForwardMode.MIXED。PR body 指出该组合之前被静默破坏,server_args.py 中没有任何验证阻止它,但一旦调度器进入 MIXED 模式就会崩溃。

值得合入,修复严重崩溃 bug,改动极小且带有回归测试。可关注后续是否将 --enable-two-batch-overlap 与 mixed chunk 的支持补全。

讨论亮点

gemini-code-assist[bot] 建议将 forward_mode == ForwardMode.EXTEND or forward_mode == ForwardMode.MIXED 改为 forward_mode in (ForwardMode.EXTEND, ForwardMode.MIXED) 以提高可读性。该建议未被采纳,但属于风格优化,不影响功能正确性。

实现拆解

  1. 修改 split 逻辑:在 python/sglang/srt/batch_overlap/two_batch_overlap.pycompute_split_seq_index(第 84 行)和 compute_split_token_index(第 273 行)中,将条件从 forward_mode == ForwardMode.EXTEND 扩展为 forward_mode == ForwardMode.EXTEND or forward_mode == ForwardMode.MIXED,使 MIXED 模式复用 EXTEND 的分割逻辑。
  2. 原理说明mix_with_running 操作后,running decode 请求被追加到 extend_lens 作为长度为 1 的条目,因此 _split_extend_seqs 和累积和分割逻辑可直接复用。
  3. 添加回归测试:在 test/registered/distributed/test_dp_attention.py 中新增 TestDPAttentionMixedChunk 测试类,继承 CustomTestCaseGSM8KMixin,启动服务器时传入 --enable-dp-attention --dp 2 --enable-mixed-chunk --chunked-prefill-size 256 参数,并通过 GSM8K 准确率阈值 0.6 确保推理正确性。
文件 模块 状态 重要度
python/sglang/srt/batch_overlap/two_batch_overlap.py 调度器 modified 6.04
test/registered/distributed/test_dp_attention.py DP 注意力 modified 6.31

关键符号

compute_split_seq_index compute_split_token_index

关键源码片段

python/sglang/srt/batch_overlap/two_batch_overlap.py core-logic

核心 bugfix 文件,修改 `compute_split_seq_index` 和 `compute_split_token_index` 以处理 MIXED forward mode

def compute_split_seq_index(
    forward_mode: ForwardMode,
    num_tokens: int,
    extend_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> Optional[int]:
    # 关键变更:将 MIXED 模式视为 EXTEND,因为 mix_with_running 后
    # running decode 请求被追加为长度 1 的 extend_lens
    if forward_mode == ForwardMode.EXTEND or forward_mode == ForwardMode.MIXED:
        assert extend_lens is not None
        return _split_extend_seqs(extend_lens)
    elif forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        return (num_tokens // token_num_per_seq) // 2
    elif forward_mode.is_idle() or forward_mode.is_prebuilt():
        assert num_tokens == 0
        return 0
    else:
        raise NotImplementedError()
​
​
def compute_split_token_index(
    split_seq_index: int,
    forward_mode: "ForwardMode",
    extend_seq_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> int:
    # 同样处理 MIXED 模式
    if forward_mode == ForwardMode.EXTEND or forward_mode == ForwardMode.MIXED:
        assert extend_seq_lens is not None
        if _is_two_chunk_split_enabled(extend_seq_lens):
            return sum(extend_seq_lens) // 2
        return sum(extend_seq_lens[:split_seq_index])
    elif forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        return split_seq_index * token_num_per_seq
    elif forward_mode.is_idle():
        assert split_seq_index == 0
        return 0
    else:
        raise NotImplementedError
test/registered/distributed/test_dp_attention.py test-coverage

新增回归测试类 `TestDPAttentionMixedChunk`,验证 DP attention + mixed chunk 组合的正确性

class TestDPAttentionMixedChunk(
    CustomTestCase,
    GSM8KMixin,
):
    # 设置 GSM8K 准确率阈值为 0.6,用于验证推理正确性
    gsm8k_accuracy_thres = 0.6
​
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp", "2",
                "--enable-dp-attention",
                "--dp", "2",
                "--enable-mixed-chunk", # 之前会崩溃的选项
                "--chunked-prefill-size", "256", # 触发 chunked prefill
            ],
        )
​
    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

评论区精华

使用 in 运算符简化枚举比较 style

gemini-code-assist[bot] 建议将 `forward_mode == ForwardMode.EXTEND or forward_mode == ForwardMode.MIXED` 改为 `if forward_mode in (ForwardMode.EXTEND, ForwardMode.MIXED):` 以提高可读性。

结论:未采纳,但属于风格优化,不影响功能。 · 已解决

风险与影响

风险较低。变更仅在两处条件判断中添加 MIXED 枚举值匹配,逻辑路径与 EXTEND 一致,且通过 assert extend_lens is not None 保证前置条件。未改动的 OperationsStrategy.init_new_tbo 对 MIXED 仍会抛出 NotImplementedError,但该路径仅在 --enable-two-batch-overlap 生效时到达,而该组合仍被标记为不支持,因此无回归风险。

修复了 --enable-dp-attention--enable-mixed-chunk 的组合崩溃 bug,使 DP attention 用户可以使用 mixed chunk 功能,提升吞吐。影响范围限定于使用这两个选项的 DP attention 场景,且不涉及 Two Batch Overlap 路径。

无显著风险

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论