Prhub

#25825 [Refactor] Pass PP start_layer via model constructor instead of forward_batch.token_to_kv_pool

原始 PR 作者 ch-wan 合并时间 2026-05-20 13:16 文件变更 9 提交数 7 评论 4 代码增减 +59 / -8

执行摘要

通过构造函数传递 PP start_layer 以解耦 ForwardBatch

目前多个模型 Attention 层通过 forward_batch.token_to_kv_pool.start_layer 判断当前层是否为 PP rank 第一层,但 start_layer 是模型初始化时由 make_layers 确定的静态配置,不会在 forward 之间变化。将其从 ForwardBatch 提升为构造参数,概念更清晰,也为后续移除 ForwardBatch 中更多静态信息提供模式。

该 PR 是典型的接口清洁重构,值得精读。展示了如何分步将静态配置从运行时对象剥离,并且带测试覆盖和连带 bug 修复。设计决策(使用构造函数参数而非全局单例或上下文)值得借鉴。

讨论亮点

Gemini Code Assist 自动审查提供了三处反馈:

  • (1) ascend_backend.py 中 _cp_allgather_and_save_kv_npu 函数误用 self,可能导致 NameError,建议使用全局 pool getter(该文件不属本 PR 变更范围)。
  • (2) qwen3.py 的 Qwen3Model 需更新以计算 pp_start_layer 并传递给 make_layers(实际 PR 已做对应修改)。
  • (3) base_attn_backend.py 中建议将 pool 属性初始化为 None 避免 AttributeError
    其中 (2) 已通过提交解决,(1) 和 (3) 可作为后续改进的参考。

实现拆解

  1. 在 llama.py、glm4_moe.py、qwen3.py、qwen3_moe.py、qwen2.py、qwen2_moe.py 的 Attention 类和 DecoderLayer 类中添加 start_layer 参数(默认0),在 init 中保存为 self.start_layer。
  2. 在各模型的 init 中,调用 get_pp_indices 计算当前 PP rank 的起始层号 pp_start_layer,通过 make_layers 的 lambda 传递给每个 DecoderLayer。
  3. 在 Attention 的 forward_prepare_npu(以及 GLM4 的 forward_prepare)中,将条件判断从 self.attn.layer_id == forward_batch.token_to_kv_pool.start_layer 改为 self.attn.layer_id == self.start_layer。
  4. 修复 EAGLE 子类:llama_eagle.py、llama_eagle3.py、qwen2_eagle.py 中原本使用 positional args 调用 super().init,因新参数插入导致参数错位;改为 keyword args 确保 quant_config 正确传递。
    测试:在 B200 集群上执行了 Llama-3.1-8B、Qwen3-8B、Qwen3-30B-A3B、GLM-4.5-Air-FP8 的 PP 一致性测试,全部通过。无性能测试必要。
文件 模块 状态 重要度
python/sglang/srt/models/llama.py 模型层 modified 6.46
python/sglang/srt/models/glm4_moe.py 模型层 modified 6.3
python/sglang/srt/models/qwen3.py 模型层 modified 5.94
python/sglang/srt/models/qwen3_moe.py 模型层 modified 5.94
python/sglang/srt/models/qwen2.py 模型层 modified 5.87
python/sglang/srt/models/qwen2_moe.py 模型层 modified 5.87
python/sglang/srt/models/llama_eagle.py 模型层 modified 4.89
python/sglang/srt/models/llama_eagle3.py 模型层 modified 4.89
python/sglang/srt/models/qwen2_eagle.py 模型层 modified 4.89

关键符号

LlamaAttention.__init__ LlamaAttention.forward_prepare_npu Glm4MoeAttention.__init__ Glm4MoeAttention.forward_prepare Qwen3Attention.__init__ Qwen3Attention.forward_prepare_npu Qwen3MoeAttention.__init__ Qwen3MoeAttention.forward_prepare_npu Qwen2Attention.__init__ Qwen2DecoderLayer.__init__ Qwen2MoeDecoderLayer.__init__ LlamaEagleModel.__init__ LlamaEagle3Model.__init__ Qwen2EagleModel.__init__

关键源码片段

python/sglang/srt/models/llama.py data-contract

核心修改文件之一;展示了从 Attention、DecoderLayer 到模型类的完整 start_layer 传递链,并修改了 forward_prepare_npu 中的条件判断。

# 文件 : python/sglang/srt/models/llama.py
# 关键变更:从 ForwardBatch 中提取 start_layer 到构造参数class LlamaAttention(nn.Module):
    """Attention 层,通过构造参数接收 start_layer 而非从 forward_batch 获取。"""
    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        start_layer: int = 0, # <-- 新增参数 : 当前 PP rank 起始层号
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.start_layer = start_layer # 保存到实例,后续 forward 直接使用
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        ...
​
    def forward_prepare_npu(self, positions, hidden_states, forward_batch):
        qkv, _ = self.qkv_proj(hidden_states)
        # 原判断 : self.attn.layer_id == forward_batch.token_to_kv_pool.start_layer
        # 现改为 : 直接对比 self.start_layer(类初始化时已确定)
        if self.attn.layer_id == self.start_layer:
            self.rotary_emb.get_cos_sin_with_position(positions)
        q, k, v = split_qkv_rmsnorm_rope(
            qkv,
            self.rotary_emb.position_sin,
            self.rotary_emb.position_cos,
            self.q_size, self.kv_size, self.head_dim,
            eps=self.q_norm.variance_epsilon,
            q_weight=self.q_norm.weight,
            k_weight=self.k_norm.weight,
            q_bias=getattr(self.q_norm, "bias", None),
            k_bias=getattr(self.k_norm, "bias", None),
        )
        return q, k, v# 在模型类 (LlamaModel) 的 __init__ 中,利用 get_pp_indices 计算 pp_start_layer
# 并通过 make_layers 的 lambda 传递给每一层from sglang.srt.distributed import get_pp_indicesclass LlamaModel(nn.Module):
    def __init__(self, config, ...):
        super().__init__()
        ...
        # 通过 get_pp_indices 获取当前 PP rank 负责的首层索引
        pp_start_layer, _ = get_pp_indices(
            config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: LlamaDecoderLayer(
                config=config,
                quant_config=quant_config,
                layer_id=idx,
                start_layer=pp_start_layer, # <-- 每一层都拿到相同的 pp_start_layer
                prefix=prefix,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
        )
python/sglang/srt/models/glm4_moe.py data-contract

GLM4-MoE 模型的 Attention 层和 DecoderLayer 同样需要 start_layer 迁移,并修改 forward_prepare 中的条件判断。

# 文件 : python/sglang/srt/models/glm4_moe.py
# 关键变更:从 ForwardBatch 中提取 start_layer 到构造参数class Glm4MoeAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        start_layer: int = 0, # <-- 新增参数 : 当前 PP rank 起始层号
        rope_theta: float = 1000000,
        partial_rotary_factor: float = 0.5,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-05,
        attention_bias: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        use_qk_norm: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.start_layer = start_layer # 直接存储,不再依赖 forward_batch
        attn_tp_rank = get_attention_tp_rank()
        ...
​
    def forward_prepare(self, positions, hidden_states, forward_batch):
        qkv, _ = self.qkv_proj(hidden_states)
        # 原来的判断使用了 forward_batch.token_to_kv_pool.start_layer,
        # 现在改为使用实例属性 self.start_layer
        if self.attn.layer_id == self.start_layer:
            self.rotary_emb.get_cos_sin_with_position(positions)
        ...# 在模型类 (Glm4MoeModel) 的 __init__ 中,利用 get_pp_indices 计算 pp_start_layer
from sglang.srt.distributed import get_pp_indices# 在 Glm4MoeModel ( 具体类名以源码为准 ) 的 __init__ 内 :
pp_start_layer, _ = get_pp_indices(
    config.num_hidden_layers,
    self.pp_group.rank_in_group,
    self.pp_group.world_size,
)
self.layers, self.start_layer, self.end_layer = make_layers(
    config.num_hidden_layers,
    lambda idx, prefix: Glm4MoeDecoderLayer(
        layer_id=idx,
        start_layer=pp_start_layer, # <-- 传递
        config=config,
        quant_config=quant_config,
        prefix=prefix,
    ),
    pp_rank=self.pp_group.rank_in_group,
    pp_size=self.pp_group.world_size,
)

评论区精华

ascend_backend.py 中 NameError 风险 正确性

reviewer 指出函数 _cp_allgather_and_save_kv_npu 使用 self.token_to_kv_pool 但该函数是模块级函数,无 self,会引发 NameError。建议使用全局池 getter。

结论:该文件不属本 PR 变更范围,但反馈可作为后续改进参考。 · unresolved

Qwen3Model 中缺少 get_pp_indices 计算 正确性

reviewer 指出 Qwen3Model 未更新计算 pp_start_layer 并传递给 make_layers,会导致 start_layer 始终为 0。

结论:实际 PR 已包含对应修改(提交 ac113dad),reviewer 可能基于旧 diff。 · 已解决

base_attn_backend.py 属性懒初始化建议 设计

reviewer 建议将 pool 属性初始化为 None 以提高健壮性,避免 model_runner 为 None 时 AttributeError。

结论:PR 未涉及该文件,但建议合理。 · unresolved

风险与影响

主要风险是迁移不完整导致某些 PP 场景下 start_layer 始终为0,但 PR 已逐一模型修改并通过 PP 一致性测试覆盖。EAGLE 子类参数错位已在最后提交用 keyword args 修复,但未来新增类似子类时需注意参数位置。由于是纯重构,无运行时行为变化,回归风险低。无性能风险。

对用户无直接影响,功能一致。对开发者:减少了 ForwardBatch 的语义污染,明确了静态配置的传递路径;为后续从 ForwardBatch 移除更多类似于 token_to_kv_pool 的引用提供了参考模式。影响范围覆盖主流 Transformer 模型族,包括 LLaMA、Qwen2/3、Qwen2-MoE、Qwen3-MoE、GLM4-MoE。

多模型文件改造 EAGLE 子类参数顺序风险(已修复) 需确保所有调用链更新

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论