Prhub

#25010 [Spec] Remove dead kernel params; fix stale comment in `trtllm_mla`

原始 PR 作者 hnyls2002 合并时间 2026-05-12 05:50 文件变更 4 提交数 2 评论 4 代码增减 +19 / -25

执行摘要

清理投机解码和注意力中的死参数与注释

该 PR 的动机是移除死亡参数和过时注释,以清理代码库。PR body 指出这些内容是 'Three independent cleanups in spec / attention scope',并将此 PR 描述为从 #24081 分离出的琐碎清理部分。移除未使用的参数可以减少混淆并简化接口。

值得快速审查并合并。这是一个教科书式的代码清理 PR:移除死亡代码、修正过时注释。对于关注代码健康的团队,值得鼓励此类 PR。此外,其中关于 Mamba state scatter kernel 参数命名的清理可作参考,展示了如何使测试变量名与生产代码对齐。

讨论亮点

该 PR 的 review 中没有实质性讨论或争议。所有评论均来自自动化 bot(gemini-code-assist 和 github-actions),主要涉及 CI 重跑。作者手动触发了相关测试(test_mamba_state_scatter_triton.pytest_kda_kernels.py)且均已通过。

实现拆解

  1. 移除 fla/kda.py 中的死亡参数 num_accepted_tokens:在 fused_recurrent_kda_fwd 函数签名和 kernel launch 调用中删除 num_accepted_tokens 参数。该参数在函数内部从未被读取(唯一定向引用被注释掉),且唯一调用者硬编码为 None。同时更新 fused_recurrent_kda 调用处。
  2. 移除 mamba_state_scatter_triton.py 中的死亡参数 total_requests:在 _fused_mamba_state_scatter_with_mask_kernel Triton kernel 签名和 fused_mamba_state_scatter_with_mask 调用中删除 total_requests 参数。该参数从未被 kernel 读取,因为 pid_req 来自 tl.program_id,边界检查使用 src_req_sizedst_req_size
  3. 修正 trtllm_mla_backend.py 中的过时注释:在 pad_draft_extend_query_kernel 中修正两条注释:cumsum_ptr 的注释从 'Cumulative sum of accept lengths' 改为 'Cumulative sum of sequence lengths',以及注释 'Load accept length for this batch' 改为 'Load sequence length for this batch'。这些变量实际上加载的是 seq_lens_q,而非接受长度。
  4. 测试配套更新:在 test_mamba_state_scatter_triton.py 中,将参考实现和 fused 实现中的参数名从 accepted_steps/request_number 改为 step_indices_raw/total_requests,以对齐生产 kernel 的命名。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/fla/kda.py 注意力 modified 4.98
python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py Mamba modified 4.7
python/sglang/srt/layers/attention/trtllm_mla_backend.py MLA modified 4.83
test/registered/unit/layers/test_mamba_state_scatter_triton.py 测试 modified 5.12

关键源码片段

python/sglang/srt/layers/attention/fla/kda.py core-logic

移除 `fused_recurrent_kda_fwd` 和 `fused_recurrent_kda` 中未使用的 `num_accepted_tokens` 参数。

# 变更后的 fused_recurrent_kda_fwd 签名:不再有 num_accepted_tokens 参数
def fused_recurrent_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    # ssm_state_indices: torch.Tensor | None = None, # 仍保留但已注释
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    ...
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        ...
        # num_accepted_tokens=num_accepted_tokens, # 此行已被移除
        ...
    )
python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py core-logic

从 `_fused_mamba_state_scatter_with_mask_kernel` 和 `fused_mamba_state_scatter_with_mask` 中移除未使用的 `total_requests` 参数。

# 变更后的 kernel 签名(@triton.jit):不再有 total_requests 参数
@triton.jit
def _fused_mamba_state_scatter_with_mask_kernel(
    src_ptr,
    dst_ptr,
    dst_indices_raw_ptr, # [total_requests] - state_indices_tensor
    step_indices_raw_ptr, # [total_requests] - accepted_steps / mamba_steps_to_track
    elem_per_entry: tl.constexpr, # 原 total_requests 参数被移除
    src_layer_stride,
    src_req_stride,
    src_step_stride,
    dst_layer_stride,
    dst_req_stride,
    src_req_size,
    src_step_size,
    dst_req_size,
    BLOCK_SIZE: tl.constexpr,
):
    # pid_req 来自 tl.program_id(0),而非参数;边界检查使用 src_req_size / dst_req_size
    pid_req = tl.program_id(0)
    ...

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

低风险。所有变更均为纯清理:移除死亡参数(函数内部未读取,调用者硬编码为 None)和修正注释。测试套件已通过,且无行为变更。然而,需注意:若未来有外部代码依赖被移除的参数(如通过 **kwargs 传递),则可能出现兼容性问题。但当前代码库中无此类使用。

直接影响范围小:仅修改 4 个文件,共减少 25 行代码。无用户可见变更。对系统性能无影响。对团队的影响是降低了代码复杂度,减少了未来维护者的认知负荷。

低风险清理

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论