Prhub

#22462 [PD][Bugfix] fix mamba cache capping

原始 PR 作者 yudian0504 合并时间 2026-04-30 10:57 文件变更 2 提交数 4 评论 7 代码增减 +19 / -5

执行摘要

修复 PD 分离部署中 Mamba 缓存池大小计算错误

PR body 指出:"In the PD disaggregation scenario, the decode node incorrectly evaluates the number of mamba slots per-req in extra_buffer mode." 即 decode 节点在 extra_buffer 模式下错误计算每个请求所需的 Mamba 槽位数,导致分配失败。

值得精读。本 PR 展示了在 PD 分离架构下处理 Mamba 缓存一致性的正确方法,特别是如何通过 server_args 进行 fallback 以及精确计算槽位需求。对于维护多节点推理系统的开发者有参考价值。

讨论亮点

Review 中主要讨论了日志信息的明确性。gemini-code-assist[bot] 建议将日志中的 size + pre_alloc_size 替换为 max_slots_needed 以避免混淆。该建议未被采纳(实际代码中日志已包含 max_slots_needed)。此外,hzh0425 在 issue 评论中总结:1)decode 侧当前不支持 radix tree,因此不能启用 extra_buffer,server_args 中的回退是必要的;2)未来 decode 侧将支持 radix tree,届时本 PR 的修复将是正确的。

实现拆解

  1. server_args.py 中禁用 decode 侧的 extra_buffer:当 disaggregation_mode == "decode" 时,如果启用了 enable_mamba_extra_buffer,则将其回退为 no_buffer,因为 decode 侧当前不支持 radix tree,无法正确使用 extra_buffer。添加警告日志。
  2. decode.py 中修正 Mamba 池大小计算:原逻辑在 mamba_size 非 None 时取 min(mamba_size, size + pre_alloc_size),在 None 时取 size + pre_alloc_size,没有考虑每个请求在 extra_buffer 下需要多个槽位(1个主槽位 + ping-pong 槽位)。新逻辑计算 slots_per_req = 1 + (mamba_ping_pong_track_buffer_size if enable_mamba_extra_buffer else 0),然后计算 max_slots_needed = (size + pre_alloc_size) * slots_per_req,最终取 max(mamba_size, max_slots_needed)max_slots_needed
  3. 日志信息调整:原警告日志提到 "capping effective_mamba_size",新日志更改为 "raising effective_mamba_size" 并包含 max_slots_neededslots_per_req 的详细信息,便于调试。
  4. 无测试配套变更:本次改动仅涉及两个源码文件,未增加或修改测试用例。
文件 模块 状态 重要度
python/sglang/srt/disaggregation/decode.py 分离推理 modified 6.68
python/sglang/srt/server_args.py 服务器配置 modified 5.75

关键符号

HybridMambaDecodeReqToTokenPool.__init__ ServerArgs._handle_pd_disaggregation

关键源码片段

python/sglang/srt/disaggregation/decode.py core-logic

核心修复文件,修改了 HybridMambaDecodeReqToTokenPool 的初始化逻辑,正确计算 Mamba 槽位需求。

# python/sglang/srt/disaggregation/decode.py 中的关键片段
class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
​
    def __init__(
        self,
        size: int,
        # ... 其他参数
        enable_mamba_extra_buffer: bool,
        enable_overlap_schedule: bool,
        mamba_size: int = None,
    ):
        # 基类初始化
        DecodeReqToTokenPool.__init__(self, ...)
​
        self.mamba_ping_pong_track_buffer_size = 2 if enable_overlap_schedule else 1
        self.enable_mamba_extra_buffer = enable_mamba_extra_buffer
​
        # 修复:每个请求需要 1 个主 Mamba 槽位 + ping-pong 槽位(如果 extra_buffer 启用)
        # 本例中 ping-pong 槽位数取决于是否启用 overlap_schedule
        slots_per_req = 1 + (
            self.mamba_ping_pong_track_buffer_size if enable_mamba_extra_buffer else 0
        )
        # 最大并发请求数为 size + pre_alloc_size
        max_slots_needed = (size + pre_alloc_size) * slots_per_req
​
        if mamba_size is not None:
            # 原逻辑:取 min,现在改为取 max,确保不小于 real need
            effective_mamba_size = max(mamba_size, max_slots_needed)
            if mamba_size < max_slots_needed:
                logger.warning(
                    "mamba_size (%d) is less than decode side's max_slots_needed "
                    "(%d = %d reqs * %d slots/req), "
                    "raising effective_mamba_size to %d",
                    mamba_size, max_slots_needed,
                    size + pre_alloc_size, slots_per_req,
                    effective_mamba_size)
        else:
            effective_mamba_size = max_slots_needed
​
        # 使用计算出的 effective_mamba_size 初始化池
        self._init_mamba_pool(
            size=effective_mamba_size,
            mamba_spec_state_size=size + pre_alloc_size,
            # ... 其他参数
            enable_mamba_extra_buffer=self.enable_mamba_extra_buffer,
        )
python/sglang/srt/server_args.py core-logic

在 PD decode 侧自动禁用 mamba extra_buffer,防止因不支持的配置导致问题。

# python/sglang/srt/server_args.py 中的关键片段
    def _handle_pd_disaggregation(self):
        if self.disaggregation_mode == "decode":
            self.disable_radix_cache = True
            logger.warning("KV cache is forced as chunk cache for decode server")
            # 新增:decode 侧当前不支持 radix tree,因此强制禁用 extra_buffer
            if self.enable_mamba_extra_buffer():
                logger.warning(
                    "Mamba extra_buffer is disabled because decode disaggregation "
                    "currently forces chunk cache. Falling back to no_buffer."
                )
                self.mamba_scheduler_strategy = "no_buffer"
        elif self.disaggregation_mode == "prefill":
            # ... prefill 逻辑不变

评论区精华

日志信息明确性 设计

gemini-code-assist[bot] 建议日志中使用 max_slots_needed 而不是 size + pre_alloc_size 以便调试。

结论:最终代码中日志已经包含了 max_slots_needed 和 slots_per_req 的详细信息,建议已体现在最终实现中。 · 已解决

风险与影响

  1. 回归风险:修改了 Mamba 池大小的核心分配逻辑,可能影响非 PD 场景下的 Mamba 缓存行为。但改动范围小且逻辑在条件分支中,非 PD 场景不受影响。
  2. 性能影响:正确计算槽位数避免了分配不足或过度分配,对性能有正面作用。
  3. 兼容性:server_args 中 disable extra_buffer 的改动可能改变用户预期的行为,但已有日志警告。

影响范围:主要在 PD 分离部署且使用 Mamba 模型的场景(如带有 Mamba 层的混合模型)。正确性修复解决了缓存分配失败问题。影响程度:中,仅影响特定配置下的特定模型。用户影响:用户不会再遇到因 Mamba 槽位不足导致的运行时错误。团队影响:无。

核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论