Prhub

#27114 [Bugfix] Restore overridden HF config fields and support index_skip_topk_offset for DSA topk sharing

原始 PR 作者 zRzRzRzRzRzRzR 合并时间 2026-06-06 13:26 文件变更 8 提交数 7 评论 28 代码增减 +100 / -5

执行摘要

修复 DSA 配置覆盖问题并支持 index_skip_topk_offset

GlmMoeDsaConfig drops/clobbers raw checkpoint fields the DSA path needs (qk_rope_head_dim, index_topk_freq), so we re-read them from config.json and restore. Fixed upstream by transformers PR #46338 — this workaround can be removed once SGLang requires transformers >= 5.10. Also adds handling for the index_skip_topk_offset config: when set, skip_topk / next_skip_topk are computed relative to the offset instead of layer 1.

本 PR 值得精读,特别是对 DSA 注意力机制、推测解码顶层索引管理、以及大规模模型服务配置兼容性感兴趣的开发者。关键设计决策包括:如何安全地跨 MTP 步骤重用 topk 索引、skip_topk 门控的精确语义、以及 TBO 与索引共享的不兼容性处理。建议在部署启用 index_topk_sharing 的模型时关注此变更。

讨论亮点

核心讨论集中在 reuse_mtp_topk_indices 字段的默认值和 TBO 兼容性。JustinTong0323 指出该字段默认为 False 而非 None,导致 TboForwardBatchPreparer.filter_batch 中的 getattr(batch, name) is not None 验证通过并抛出异常,使调度器在初始化时崩溃。zRzRzRzRzRzRzR 随后通过在 two_batch_overlap.py 中将该字段加入 output_dict 修复。其他设计权衡包括 skip_topk 门控为何在 MTP 之外必须严格避免 prev_topk_indices is None 的降级路径。

实现拆解

  1. 恢复 HF 配置字段:在 HfModelConfigParser.parse() 中,当检测到架构为 GlmMoeDsaForCausalLM 时,使用 PretrainedConfig.get_config_dict 重新读取原始 config.json,将 qk_rope_head_dimindex_topk_freq 恢复回 config 对象,并重新计算 qk_head_dim
  2. 支持 index_skip_topk_offset:在 DeepseekV2Attention.__init__ 中,当未设置 index_topk_pattern 但设置了 index_skip_topk_offset 时,根据偏移量计算 skip_topknext_skip_topk,使得 topk 共享能灵活指定起始层。
  3. 修复 skip_topk 门控条件:在 forward_mla.py 的两个分支中,将 if not self.skip_topk or prev_topk_indices is None 收紧为 if not self.skip_topk or (self.is_nextn and prev_topk_indices is None),防止在没有先前 topk 索引时错误地使用未初始化的 indexer。
  4. 支持 MTP 步骤间 topk 索引重用:在 eagle_worker.py 中添加 index_share_for_mtp_iteration 逻辑(仅 topk==1 时安全),跨多个 draft 步骤传递 topk 索引。deepseek_nextn.py 相应读取并写入 forward_batch.topk_indices
  5. 修复 TBO 兼容性:在 server_args.py 中增加检查,当启用 two-batch-overlap 且 DSA topk 共享生效时抛出错误(TBO 路径不支持跨层索引传播)。同时修复 forward_batch_info.pyreuse_mtp_topk_indices 字段默认值为 False 导致 TBO filter_batch 验证崩溃的问题,在 two_batch_overlap.py 中将其加入 pass-through 列表。
文件 模块 状态 重要度
python/sglang/srt/utils/hf_transformers/config.py 配置解析 modified 6.45
python/sglang/srt/models/deepseek_v2.py MLA 注意力 modified 6.68
python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py MLA 注意力 modified 6.19
python/sglang/srt/server_args.py 服务参数 modified 6.27
python/sglang/srt/speculative/eagle_worker.py 推测解码 modified 6.03
python/sglang/srt/models/deepseek_nextn.py NextN 模型 modified 6.02
python/sglang/srt/model_executor/forward_batch_info.py 前向批次 modified 5.46
python/sglang/srt/batch_overlap/two_batch_overlap.py 批处理重叠 modified 4.3

关键符号

HfModelConfigParser.parse DeepseekV2Attention.__init__ forward_absorb_prepare EagleWorker.draft_forward DeepseekNextN.forward TboForwardBatchPreparer.filter_batch

关键源码片段

python/sglang/srt/utils/hf_transformers/config.py dependency-wiring

核心修复:从 raw config 恢复被 GlmMoeDsaConfig 覆盖的 HF 字段,是 DSA 正确运行的前提。

class HfModelConfigParser(ModelConfigParserBase):
    # ... 其他初始化 ...
​
    def parse(self, model, trust_remote_code, revision=None, **kwargs):
        config = AutoConfig.from_pretrained(model, ...)
​
        # --- GlmMoeDsa 配置恢复:上游 transformers 5.10 后可以移除 ---
        if (
            config.architectures is not None
            and config.architectures[0] == "GlmMoeDsaForCausalLM"
        ):
            # GlmMoeDsaConfig 丢弃了 DSA 需要的原始字段
            # 重新从 config.json 读取并注入
            from transformers import PretrainedConfig
            raw_config, _ = PretrainedConfig.get_config_dict(model, revision=revision)
            for key in ("qk_rope_head_dim", "index_topk_freq"):
                if key in raw_config:
                    setattr(config, key, raw_config[key])
            # 重新计算 qk_head_dim,因为 qk_rope_head_dim 可能被覆盖
            if hasattr(config, "qk_head_dim") and hasattr(config, "qk_nope_head_dim"):
                config.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        # --- 恢复结束 ---
​
        # 后续处理 ...
python/sglang/srt/models/deepseek_v2.py data-contract

实现 index_skip_topk_offset 计算逻辑,调整 is_nextn 分支的 skip_topk 默认值(从 False 改为 True)。

class DeepseekV2Attention(nn.Module):
    def __init__(self, config, hidden_size, num_heads, ..., is_nextn=False, ...):
        super().__init__()
        # ...
        self.is_nextn = is_nextn # 新增:标记是否为 nextn 层
​
        self.skip_topk = None
        self.next_skip_topk = None
        if self.use_dsa:
            # ... indexer 初始化 ...
            if is_nextn:
                # nextn 层没有独立 indexer 权重,强制跳过并重用前一层索引
                self.skip_topk = True
                self.next_skip_topk = True
            else:
                self.index_topk_freq = getattr(config, "index_topk_freq", 1)
                self.index_topk_pattern = getattr(config, "index_topk_pattern", None)
                # 新增:index_skip_topk_offset 允许从指定偏移量开始计算
                self.index_skip_topk_offset = getattr(
                    config, "index_skip_topk_offset", None
                )
                if (
                    self.index_topk_pattern is None
                    and self.index_skip_topk_offset is not None
                ):
                    assert self.index_skip_topk_offset > 0, (
                        "index_skip_topk_offset must be positive; offset <= 0 "
                        "marks layer 0 as skip_topk with no prior topk to reuse"
                    )
                    # 相对于 offset 计算:layer_id 减去 offset 后再模 freq
                    self.skip_topk = (
                        max(layer_id - self.index_skip_topk_offset + 1, 0)
                        % self.index_topk_freq != 0
                    )
                    self.next_skip_topk = (
                        max(layer_id - self.index_skip_topk_offset + 2, 0)
                        % self.index_topk_freq != 0
                    )
                elif self.index_topk_pattern is None:
                    # 原始逻辑:从 layer 1 开始
                    self.skip_topk = max(layer_id - 1, 0) % self.index_topk_freq != 0
                    self.next_skip_topk = layer_id % self.index_topk_freq != 0
                else:
                    # 图案模式处理 ...
python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py data-contract

修复 skip_topk 时 indexer 调用的门控,避免在无权重时错误使用 indexer。

def forward_absorb_prepare(...):
    # ... 其他代码 ...
    # 非 alt_stream 分支中的 indexer 调用
    else:
        k_nope = k_nope.unsqueeze(1)
        q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        if q_lora is not None:
            # 关键门控:共享层没有 indexer 权重,
            # 当 skip_topk 为 True 时即使 prev_topk_indices 是 None
            # 也不能回退到计算模式(可能导致未初始化的 indexer)。
            # 唯一例外是 nextn 层,它有独立的 indexer 权重。
            if not self.skip_topk or (
                self.is_nextn and prev_topk_indices is None
            ):
                topk_indices = self.indexer(
                    x=hidden_states,
                    q_lora=q_lora,
                    positions=positions,
                    forward_batch=forward_batch,
                    layer_id=self.layer_id,
                )
            else:
                topk_indices = maybe_capture_indexer_topk(
                    self.layer_id, prev_topk_indices
                )
    # 类似修改也应用于 alt_stream 分支

评论区精华

reuse_mtp_topk_indices 字段默认值导致 TBO 崩溃 正确性

JustinTong0323 指出 reuse_mtp_topk_indices 默认为 False 而非 None,导致 TBO filter_batch 验证抛出异常;zRzRzRzRzRzRzR 随后通过在 two_batch_overlap.py 中将其加入 pass-through 列表修复。

结论:将 reuse_mtp_topk_indices 加入 TBO filter_batch 的 passthrough copy 列表,使其正确传递。 · 已解决

风险与影响

  1. 配置恢复兼容性:从 config.json 直接读取字段可能忽略其他覆盖,但只在 GlmMoeDsaForCausalLM 架构下生效,且上游 transformers 已确认修复,风险有限。
  2. index_skip_topk_offset 正确性:新引入的计算逻辑相对于固定偏移,与频率模式、图案模式三者互斥,但若同时设置可能出错,现有断言覆盖条件。
  3. skip_topk 门控收紧:在 forward_mla.py 中,跳过 indexer 计算的条件从宽松变为严格,任何未成功传播顶层索引的路径都会导致稀疏注意力的索引无效,虽然目前已知路径(TBO、无重用)都会触发,但仍然可能存在隐藏路径。
  4. MTP topk 重用安全性:仅在 topk==1 时启用,因为 select_top_k_tokens 在多候选时会重新排序行。若后续 topk>1 被启用,索引会错位。
  5. TBO 兼容性:新增的 ValueError 可能在用户配置不当时停止启动,是主动安全检查,风险低。

直接影响所有使用 DSA 模式的 DeepSeek V3.2/GLM 5 模型部署,尤其是那些具有 index_topk_freq>1index_topk_pattern 的模型(即启用 topk 共享)。同时影响使用 MTP 推测解码且 topk==1 的场景,以及使用 two-batch-overlap 的场景(通过禁止不兼容配置或修复传递问题)。不影响非 DSA 模式或非 DeepSeek 模型。整体影响面集中在 DeepSeek 高性能推理路径上。

配置恢复依赖上游 skip_topk 门控变更 topk 重用条件限制 TBO 不兼容检查

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论