执行摘要
- 一句话:MLX 后端实现解码异步重叠调度
- 推荐动作:值得精读。该 PR 展示了如何利用 MLX 的 lazy evaluation 特性设计高效的 GPU 流水线,是 Apple Silicon 推理性能优化的核心里程碑。
SchedulerMlxOverlapMixin 中的链式调度设计(两图链、链中断条件、async_eval 与 finalize 分离)具有较高参考价值。后续可以考虑扩展到 prefill/extend 链以及更鲁棒的 KV 缓存管理。
功能与动机
MLX 后端原有的实现中,每次解码步骤都会导致 CPU 与 GPU 同步,产生 GPU 空闲间隙(如 PR body 中的截图所示),限制了吞吐。该 PR 参考了 SGLang CUDA 版本的重叠调度设计(参见 https://www.lmsys.org/blog/2024-12-04-sglang-v0-4/),旨在通过异步评估消除这些空闲间隙,提升 Apple Silicon 上的生成性能。关联 Issue #22114 和 #22466。
实现拆解
-
新增 SchedulerMlxOverlapMixin(python/sglang/srt/hardware_backend/mlx/scheduler_mixin.py):包含 event_loop_overlap_mlx 主循环,维护两个 in-flight 的 MLX 计算图(pending_curr 和 pending_next),利用 mx.async_eval 实现 GPU 流水线。定义了 MlxPendingJob dataclass 持有未完成的 lazy 工作。
-
拆分 MlxModelRunner(python/sglang/srt/hardware_backend/mlx/model_runner.py)为 lazy API:新增 prefill_start/prefill_finalize、extend_start/extend_finalize、decode_batch_start/decode_batch_finalize 及 decode_batch_start_chained。新增 MlxPendingPrefill、MlxPendingExtend、MlxPendingDecode 数据类持有 lazy 结果。_cache_state_arrays 工具方法用于展平缓存数组,便于 mx.async_eval。
-
在 MlxTpModelWorker(python/sglang/srt/hardware_backend/mlx/tp_worker.py)中增加异步方法:async_forward_batch_generation_mlx 返回 lazy 结果,async_chained_decode_mlx 在上一 decode 的 lazy 输出上构建下一步计算图,finalize_mlx_result 阻塞等待并产生 GenerationBatchResult。提取 _cleanup_stale_rids 辅助方法。
-
调整调度器(python/sglang/srt/managers/scheduler.py):混入 SchedulerMlxOverlapMixin,根据 enable_overlap_mlx 标志在 dispatch_event_loop 中启用新的事件循环路径。分离 enable_overlap 为 CUDA 和 MLX 两个控制变量。init_overlap 中为 MLX 跳过 CUDA 流管理。
-
修改服务器参数(python/sglang/srt/server_args.py):调整 _handle_mps_backends 逻辑:仅当不使用 MLX 时才禁用 overlap schedule(因为 MLX 默认启用 overlap)。
-
修改输出处理器(python/sglang/srt/managers/scheduler_output_processor_mixin.py):process_batch_result_decode 中支持 MLX 路径的列表类型 next_token_ids,并在 finished 请求处理时同时检查 enable_overlap_mlx。
-
文档更新:在 apple_metal.mdx 和 apple_metal.md 中添加重叠调度功能说明和 benchmark 命令。
关键文件:
python/sglang/srt/hardware_backend/mlx/scheduler_mixin.py(模块 调度器混入;类别 source;类型 core-logic;符号 MlxPendingJob, SchedulerMlxOverlapMixin, event_loop_overlap_mlx, _finalize): 新增的调度器混入核心文件,包含 event_loop_overlap_mlx 主循环逻辑和 MlxPendingJob 数据类。
python/sglang/srt/hardware_backend/mlx/model_runner.py(模块 模型运行器;类别 source;类型 data-contract;符号 MlxPendingPrefill, MlxPendingExtend, MlxPendingDecode, _cache_state_arrays): 修改量最大的文件,拆分了 lazy API,定义了持有 lazy 结果的数据类。
python/sglang/srt/hardware_backend/mlx/tp_worker.py(模块 工作进程;类别 source;类型 dependency-wiring;符号 _cleanup_stale_rids, async_forward_batch_generation_mlx, _async_extend_batch, _cache_state): 增加了异步前向和后处理函数,构成 overlap scheduler 与 model runner 之间的桥梁。
关键符号:event_loop_overlap_mlx, _finalize, _launch_fresh, _launch_chained, prefill_start, prefill_finalize, extend_start, extend_finalize, decode_batch_start, decode_batch_finalize, decode_batch_start_chained, async_forward_batch_generation_mlx, async_chained_decode_mlx, finalize_mlx_result, _cache_state_arrays, _cleanup_stale_rids
关键源码片段
python/sglang/srt/hardware_backend/mlx/scheduler_mixin.py
新增的调度器混入核心文件,包含 event_loop_overlap_mlx 主循环逻辑和 MlxPendingJob 数据类。
"""MLX overlap scheduling mixin for the SGLang scheduler.
Provides ``event_loop_overlap_mlx``, which pipelines MLX forward
passes by keeping two in-flight lazy graphs queued on the GPU while
the scheduler runs its CPU-side bookkeeping on the tokens of the
older one.
"""
@dataclass
class MlxPendingJob:
"""Represents an unfinished MLX forward pass queued on the GPU.
``lazy_tokens``: mlx.array of token IDs, not yet evaluated.
``prefills``/``extends``/``decode``: per-mode state for finalization.
``mode``: ``"decode"``, ``"extend"``, or ``"idle"``.
``batch_copy``: snapshot of ``ScheduleBatch`` at launch time,
decoupled from the live batch to avoid races.
"""
lazy_tokens: Optional[mx.array]
prefills: list["MlxPendingPrefill"]
extends: list["MlxPendingExtend"]
decode: Optional["MlxPendingDecode"]
mode: str
batch_copy: "ScheduleBatch"
reqs: List[Req]
class SchedulerMlxOverlapMixin:
"""Mixin that adds MLX overlap scheduling to :class:`Scheduler`."""
@DynamicGradMode()
def event_loop_overlap_mlx(self: "Scheduler"):
"""MLX overlap loop modelled on ``mlx_lm.generate.generate_step``.
At steady state we keep TWO in-flight MLX graphs:
``pending_curr`` (about to be finalized) and
``pending_next`` (built on top of ``pending_curr``'s lazy output,
already handed to ``mx.async_eval``).
"""
# Initialize state
pending_curr: Optional[MlxPendingJob] = None
pending_next: Optional[MlxPendingJob] = None
self.result_queue.clear()
while not self.is_shutdown:
# Finalize the previous step's pending_curr, if any
if pending_curr is not None:
self._finalize(pending_curr)
# If a chained next step exists, shift it to current
if pending_next is not None:
pending_curr = pending_next
pending_next = None
else:
# Otherwise schedule a new batch
batch = self.get_next_batch_to_run()
if batch is None:
pending_curr = None
continue
pending_curr = self._launch_fresh(batch)
# Decide whether we can chain a next decode step
if self._can_chain(pending_curr):
pending_next = self._try_launch_chained(pending_curr)
# Block on pending_curr tokens to feed into bookkeeping
# (the GPU is already running pending_next in the background)
_ = pending_curr.lazy_tokens.tolist()
python/sglang/srt/hardware_backend/mlx/model_runner.py
修改量最大的文件,拆分了 lazy API,定义了持有 lazy 结果的数据类。
@dataclass
class MlxPendingPrefill:
"""Lazy prefill state, finalized after ``mx.eval``.
``cache`` is per-layer ``ContiguousKVCache`` list for commit.
"""
lazy_token: mx.array
cache: list # list[ContiguousKVCache]
req_id: str
full_token_ids: list[int]
req_pool_idx: int
synced_offset: int
@dataclass
class MlxPendingExtend:
"""Lazy chunked-prefill-continuation state for an existing request.
Uses the request's existing per-layer cache.
"""
lazy_token: mx.array
req_id: str
new_token_ids: list[int]
new_synced_offset: int
@dataclass
class MlxPendingDecode:
"""Lazy decode state for a batch.
``caches``: per-request list of per-layer ``ContiguousKVCache``
references that the attention wrapper writes into.
"""
lazy_tokens: mx.array
req_ids: list[str]
caches: list # list[list[ContiguousKVCache]]
class MlxModelRunner:
# ... (existing fields)
def decode_batch_start(self, req_ids: list[str]) -> MlxPendingDecode:
"""Start a decode step without evaluating.
Builds the compute graph, writes KV caches in-place,
and returns lazy token output.
"""
# ... merge KV caches, run model forward, collect lazy tokens
# Return MlxPendingDecode without calling mx.eval
return MlxPendingDecode(
lazy_tokens=logits.argmax(-1),
req_ids=req_ids,
caches=merged_caches,
)
def decode_batch_start_chained(self, prev: MlxPendingDecode) -> MlxPendingDecode:
"""Launch next decode step on top of a still-lazy previous decode.
Reuses the same cache objects, so MLX tracks the dependency.
"""
# Build graph using prev's lazy (unevaluated) output as input
# and the same cache lists (already updated in-place).
return self.decode_batch_start(prev.req_ids)
def decode_batch_finalize(self, pending: MlxPendingDecode) -> list[int]:
"""Block on lazy tokens and return token IDs.
Evaluates tokens together with cache arrays to materialize writes.
"""
cache_arrays = [
arr for c_list in pending.caches for arr in self._cache_state_arrays(c_list)
]
mx.eval(pending.lazy_tokens, *cache_arrays)
return pending.lazy_tokens.tolist()
python/sglang/srt/hardware_backend/mlx/tp_worker.py
增加了异步前向和后处理函数,构成 overlap scheduler 与 model runner 之间的桥梁。
def async_forward_batch_generation_mlx(
self,
model_worker_batch: ModelWorkerBatch,
) -> tuple[
Union[mx.array, None],
list[MlxPendingPrefill],
list[MlxPendingExtend],
Optional[MlxPendingDecode],
str,
]:
"""Start an async (lazy) forward pass through the MLX model runner.
Returns (lazy_result, prefills, extends, decode, mode) without
blocking on the GPU. The caller can later call ``finalize_mlx_result``
to block and produce a ``GenerationBatchResult``.
"""
forward_mode = model_worker_batch.forward_mode
reqs = model_worker_batch.reqs
if forward_mode.is_idle():
return (None, [], [], None, "idle")
self._cleanup_stale_rids(forward_mode, {req.rid for req in reqs})
if forward_mode.is_extend():
# ... build lazy extend graphs, return pending state
pass
else:
# Decode: use decode_batch_start
pending_decode = self._mlx_runner.decode_batch_start(
[req.rid for req in reqs]
)
return (pending_decode.lazy_tokens, [], [], pending_decode, "decode")
def async_chained_decode_mlx(self, prev_decode: MlxPendingDecode) -> MlxPendingDecode:
"""Build the next decode step on top of a still-lazy previous decode.
Reuses the cache objects from prev_decode, so MLX tracks the
dependency graph. The caller should hand the result to
``mx.async_eval`` immediately.
"""
next_decode = self._mlx_runner.decode_batch_start_chained(prev_decode)
# Fire async evaluation: GPU will execute this step as soon as
# the previous step's dependencies are resolved.
mx.async_eval(next_decode.lazy_tokens)
return next_decode
def finalize_mlx_result(self, pending_job: MlxPendingJob) -> GenerationBatchResult:
"""Block on lazy tokens and produce a normal GenerationBatchResult.
Depending on pending_job.mode, calls prefill/extend/decode finalize.
"""
if pending_job.mode == "decode":
next_token_ids = self._mlx_runner.decode_batch_finalize(pending_job.decode)
# ... build GenerationBatchResult
elif pending_job.mode == "extend":
# ... merge prefills and extends
pass
# ... return GenerationBatchResult
评论区精华
- Mixin 设计:yeahdongcn 建议将 MLX 重叠逻辑抽取为独立的 mixin 类以减少对
scheduler.py 的侵入,得到采纳。
- enable_overlap 变量命名:yeahdongcn 提议将 CUDA 相关变量重命名为
enable_overlap_torch 或 enable_overlap_cuda,以清晰区分 MLX 和 CUDA 路径。最终保留 enable_overlap 作为通用标志,新增 enable_overlap_mlx 专门控制 MLX。
- 导入路径错误:yeahdongcn 指出
server_args.py 中存在错误的相对导入 from python.sglang.srt...,后修正为绝对导入。
- BatchedKVCacheManager 设计:alexnails 提出是否考虑使用持久的 KV 缓存管理器(类似 CUDA 的
BatchedKVCacheManager)以简化代码。changminbark 解释由于 MLX 的 lazy eval 和 per-request ContiguousKVCache 设计,当前实现更简单且避免了额外拷贝。
- 链中断与请求完成:alexnails 和 Kangyan-Zhou 讨论了当请求完成时链的行为,发现
process_batch_result_decode 中的 finished-request guard 没有检查 enable_overlap_mlx,导致重复释放 KV 缓存。最终通过添加 self.enable_overlap_mlx 检查修复。
- 性能优化前瞻:alexnails 建议将
mx.array(ctx.seq_lens) 缓存到 BatchedDecodeContext 中避免每层重复创建,changminbark 同意并放入后续 PR 中。
- Mixin 设计建议 (design): 采纳建议,创建了 SchedulerMlxOverlapMixin。
- enable_overlap 变量命名与分离 (design): 最终在 scheduler.py 中引入 enable_overlap_mlx,enable_overlap 同时用于 CUDA 和 MLX 的通用条件。
- 导入路径错误 (correctness): 修正为正确的导入路径。
- BatchedKVCacheManager 设计质疑 (design): 保持当前 per-request cache 设计,未引入 BatchedKVCacheManager。
- finished-request guard 缺少 MLX 检查 (correctness): 添加 self.enable_overlap_mlx 检查到 guard 条件中。
- seq_lens 缓存优化 (performance): changminbark 同意并将此优化放入后续 PR。
风险与影响
- 风险:
- 内存管理复杂性:MLX lazy evaluation 会累积未评估的计算图,需依赖
mx.clear_cache() 定期清理(_decode_step_ct 计数器触发的 mx.metal.clear_cache() 调用),若清理不当可能导致内存泄漏。
- 链中断导致浪费:当请求完成时,已启动的
pending_next 仍然被评估,多余的一个 token 被丢弃,造成约一步 decode 的计算浪费。
- 仅 decode-decode 链:预填充、扩展或 batch 组合变化会中断链,回退到标准路径,部分重叠无法覆盖。
- 缺少单元测试覆盖:当前没有针对
event_loop_overlap_mlx 的单元测试,仅依靠手动功能测试和 benchmark,回归风险较高。
- 与调度器核心架构耦合:mixinin 方式虽然减少了侵入,但仍需在
scheduler.py 的关键路径中添加条件判断,未来核心重构可能需要同步调整。
- 影响:影响范围:仅 Apple Silicon(macOS)上使用 MLX 后端的用户,其他后端(CUDA/ROCm/CPU)不受影响。影响程度:默认启用,用户可通过 --disable-overlap-schedule 或环境变量 SGLANG_USE_MLX=0 关闭。预期能显著提升连续解码吞吐(尤其是在批量较大时),但对预填充或混合 batch 场景提升有限。文档已更新用法说明。团队:维护需要了解 MLX lazy evaluation 和 overlap 调度机制的开发者。
- 风险标记:缺少测试覆盖, 内存管理复杂, 链中断导致计算浪费, 仅 decode-decode 链
关联脉络
- PR #21509 [Apple Silicon] RadixCache and ContiguousKVCache support: 本 PR 依赖于 #21509 引入的 RadixCache 和 ContiguousKVCache 重构,PR 作者在开发过程中多次合并该分支以适配新缓存架构。
参与讨论