Prhub

#24785 Fix reduce_scatterv producer contract for SUM_LEN

原始 PR 作者 YAMY1234 合并时间 2026-05-11 07:51 文件变更 1 提交数 1 评论 3 代码增减 +4 / -2

执行摘要

修复 SUM_LEN 模式下 reduce_scatterv 合约错误

关联 Issue #23554 报告 Kimi K2.6 在 DEP8 配置下产生垃圾输出(GSM8K 准确率仅 1.2%),而 DP8 正常工作。根本原因是 LayerCommunicator.should_use_reduce_scatter() 仅在 MAX_LEN 模式下对 _scatter_hidden_states 返回 true,但在 SUM_LEN 模式下,生产层仍然执行了 TP all-reduce,随后通信器又执行了 reduce_scatterv,导致隐藏状态被双重规约,引发长 prompt 精度回归。Kimi K2.6 由于 first_k_dense_replace=1 包含早期密集 MLP 层,因此暴露了此问题。

建议精读该 PR 以理解 DP reduce_scatterv 生产合约的关键设计思路。这是一个典型的生产合约 Bug,修复逻辑清晰但影响重大,值得作为分布式推理中通信合约设计的案例研究。

讨论亮点

无实质性讨论。reviewer mmangkad 和 b8zhong 均批准了该 PR,gemini-code-assist[bot] 的自动审查也没有提供反馈。PR body 中详细记录了复现步骤和修复前后的 GSM8K 准确率对比(从 0.56 提升至 0.975),b8zhong 在 issue 评论中确认了修复效果。

实现拆解

  1. 修改 should_use_reduce_scatter 方法python/sglang/srt/layers/communicator.py 第 700-707 行):将原先要求 forward_batch.dp_padding_mode.is_max_len() 才能返回 true 的条件拆分,新增对 should_use_dp_reduce_scatterv() 的检查。如果该函数返回 true,则直接返回 true,无需检查 dp_padding_mode
  2. 保留原有 is_max_len 分支:当 should_use_dp_reduce_scatterv() 为 false 时,仍按原有逻辑检查 is_max_len() 以兼容其他场景。
  3. 调整后的条件逻辑:只要通信器计划执行 DP reduce_scatterv(通过 should_use_dp_reduce_scatterv() 判断),生产层就会跳过内部的 TP all-reduce,从而避免双重规约。
文件 模块 状态 重要度
python/sglang/srt/layers/communicator.py 通信层 modified 6.16

关键符号

should_use_reduce_scatter

关键源码片段

python/sglang/srt/layers/communicator.py core-logic

核心修复文件,修改了 `should_use_reduce_scatter` 方法中的条件判断,增加了 `should_use_dp_reduce_scatterv()` 检查,使 `SUM_LEN` 模式下也能正确跳过生产层的 TP all-reduce。

# python/sglang/srt/layers/communicator.pydef should_use_reduce_scatter(self, forward_batch: ForwardBatch):
    if not self.allow_reduce_scatter:
        return False
    if (
        self._communicate_summable_tensor_pair_fn
        is CommunicateSummableTensorPairFn._scatter_hidden_states
    ):
        # 如果启用了 DP reduce_scatterv,则生产者应跳过 TP all-reduce
        # 以符合通信器的 reduce_scatterv 合约,避免双重规约。
        if should_use_dp_reduce_scatterv():
            return True
        # 原来只检查 MAX_LEN 模式,现在两种模式都覆盖
        if forward_batch.dp_padding_mode.is_max_len():
            return True
    if nsa_use_prefill_cp(forward_batch):
        return True
    if get_attn_tp_context().input_scattered and not self.is_last_layer:
        return True
    return False

评论区精华

修复 reduce_scatterv 生产合约 正确性

PR body 详细描述了问题根因:`should_use_reduce_scatter()` 只在 MAX_LEN 模式下对 `_scatter_hidden_states` 返回 true,SUM_LEN 模式下生产层仍执行 TP all-reduce,导致双重规约和精度退化。

结论:修改为在 `scatter_hidden_states` 分支中同时检查 `should_use_dp_reduce_scatterv()` 和 `is_max_len()`,修复 SUM_LEN 模式下的合约问题。 · 已解决

风险与影响

低风险。该 PR 修改仅 6 行,核心逻辑是 should_use_dp_reduce_scatterv() 的添加,该函数在其他地方已有使用,逻辑正确。可能的风险点包括:

  • should_use_dp_reduce_scatterv() 在某些不受支持的场景下被错误启用,可能导致生产层错误跳过 TP all-reduce。但该函数在其他路径(如 _scatter_hidden_states 自身)已正确使用,因此风险较低。
  • 未修改单元测试,但 PR body 提供了充分的 GSM8K 准确率验证和速度测试。

直接影响所有使用 DP reduce_scatterv 且在 SUM_LEN 模式下运行推理的场景,特别是包含密集 MLP 层的 MoE 模型(如 Kimi K2.6)。修复后,这些场景的精度将从几乎不可用恢复到正常水平。不影响 DP8 模式(不启用 reduce_scatterv)或 MAX_LEN 模式。用户无需更改配置或代码即可受益。

核心路径变更 缺少测试覆盖

关联 Issue

#23554 [Bug] Kimi K2.6 DEP8 produces garbage output but DP8 works fine

完整报告

参与讨论