Prhub

#26000 [codex] Centralize Triton utility kernels

原始 PR 作者 BBuf 合并时间 2026-06-02 16:47 文件变更 29 提交数 12 评论 106 代码增减 +3681 / -3547

执行摘要

集中分散的 Triton 工具内核到 triton_ops 子包

PR body 说明:若干非 diffusion Triton 内核嵌入在大型后端/辅助文件中,此 PR 将这些内核集中到领域特定的 triton_ops 模块,作为纯导入重构,不改内核行为。

值得精读,了解 sglang 中 Triton 内核的组织方式。对于新贡献者,了解 triton_ops 布局有助于快速定位内核。

讨论亮点

仅有一条审核评论,来自 DarkSharpness:'We indeed should do this long before.'(我们早就应该做这件事了。)表明社区普遍认同这一重构方向。

实现拆解

  1. 在 python/sglang/srt/layers/attention/triton_ops/ 下新建 cache_ops.py、kv_indices.py、pad.py、rope_cache.py、metadata.py,将原在 attention/utils.py 和 flashattention_backend.py 中的 Triton 内核按功能搬入。
  2. 在 python/sglang/srt/mem_cache/triton_ops/ 下新建 mla_buffer.py、common.py,搬入 MLA KV buffer 操作和内存池相关内核。
  3. 在 python/sglang/srt/speculative/triton_ops/ 下新建 cache_locs.py、multi_layer_eagle.py,搬入推测解码缓存位置和多层 Eagle 内核。
  4. 原文件改为从新模块重新导入所有公开符号(as 别名保持命名一致),确保所有已有调用方无需修改。
  5. 调整导入关系,删除原文件中 triton、torch 导入(不再需要),并删除内联内核定义。
  6. 合并 main 分支时解决冲突,确保导入顺序正确。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/utils.py 注意力层 modified 8.84
python/sglang/srt/speculative/spec_utils.py 推测解码 modified 8.65
python/sglang/srt/layers/attention/flashattention_backend.py 注意力层 modified 8.65
python/sglang/srt/mem_cache/utils.py 内存缓存 modified 8.65
python/sglang/srt/speculative/multi_layer_eagle_utils.py 推测解码 modified 8.69

关键符号

create_flashinfer_kv_indices_triton create_flashmla_kv_indices_triton concat_and_cast_mha_k_triton pad_sequence_with_mask seqlens_expand_triton set_mla_kv_buffer_triton set_mla_kv_buffer_triton_fp8_quant get_mla_kv_buffer_triton normal_decode_set_metadata prepare_swa_spec_page_table_triton create_extend_after_decode_spec_info assign_req_to_token_pool assign_draft_cache_locs generate_draft_decode_kv_indices rotate_input_ids_triton assign_new_state_triton assign_hidden_states_pool_triton write_req_to_token_pool_triton get_last_loc_triton canonicalize_stride

关键源码片段

python/sglang/srt/layers/attention/utils.py core-logic

曾经内联定义 8 个 Triton 内核;修改后删除了所有内核定义,改为从 triton_ops 导入,是此次重构的核心示范文件。

# 文件 python/sglang/srt/layers/attention/utils.py ( 重构后版本 )
# 原内联的 Triton 内核已迁移至 triton_ops 子模块,此处仅导入并保持别名from sglang.jit_kernel.utils import is_arch_support_pdl
from sglang.srt.layers.attention.triton_ops.cache_ops import (
    concat_and_cast_mha_k_kernel as concat_and_cast_mha_k_kernel,
    concat_and_cast_mha_k_triton as concat_and_cast_mha_k_triton,
    launch_reshape_and_cache_flash as launch_reshape_and_cache_flash,
    reshape_and_cache_flash as reshape_and_cache_flash,
)
from sglang.srt.layers.attention.triton_ops.kv_indices import (
    create_flashinfer_kv_indices_triton as create_flashinfer_kv_indices_triton,
    create_flashmla_kv_indices_triton as create_flashmla_kv_indices_triton,
    get_num_page_per_block_flashmla as get_num_page_per_block_flashmla,
)
from sglang.srt.layers.attention.triton_ops.pad import (
    pad_sequence_with_mask as pad_sequence_with_mask,
    pad_sequence_with_mask_kernel as pad_sequence_with_mask_kernel,
    seqlens_expand_kernel as seqlens_expand_kernel,
    seqlens_expand_triton as seqlens_expand_triton,
)
from sglang.srt.layers.attention.triton_ops.rope_cache import (
    fused_qk_rope_reshape_and_cache as fused_qk_rope_reshape_and_cache,
)
from sglang.srt.utils import is_cuda_is_cuda = is_cuda()
if _is_cuda:
    from sglang.jit_kernel.concat_mla import concat_mla_absorb_q# 新增函数:处理因 num_kv_heads=1 产生的退化 stride,
# 避免 flashinfer TMA 描述校验失败(详见 issue #2232)
def canonicalize_stride(tensor: torch.Tensor) -> torch.Tensor:
    sizes = tensor.size()
    strides = tensor.stride()
    ndim = tensor.dim()
    need_fix = any(
        sizes[i] == 1 and strides[i] == strides[i + 1] for i in range(ndim - 1)
    )
    if not need_fix:
        return tensor
    new_strides = [0] * ndim
    new_strides[-1] = 1
    for i in range(ndim - 2, -1, -1):
        new_strides[i] = max(new_strides[i + 1] * sizes[i + 1], strides[i])
    return tensor.as_strided(sizes, new_strides)
python/sglang/srt/mem_cache/utils.py core-logic

删除了 set_mla_kv_buffer_kernel 等 8 个内核定义,改为从 triton_ops.mla_buffer 导入,是内存缓存层的核心重构。

# 文件 python/sglang/srt/mem_cache/utils.py ( 重构后头部 )
# 原内联的 MLA KV Buffer Triton 内核已迁移至 triton_ops.mla_buffer,此处仅导入
from sglang.srt.mem_cache.triton_ops.mla_buffer import (
    get_mla_kv_buffer_kernel as get_mla_kv_buffer_kernel,
    get_mla_kv_buffer_triton as get_mla_kv_buffer_triton,
    set_mla_kv_buffer_fp8_quant_kernel as set_mla_kv_buffer_fp8_quant_kernel,
    set_mla_kv_buffer_kernel as set_mla_kv_buffer_kernel,
    set_mla_kv_buffer_triton as set_mla_kv_buffer_triton,
    set_mla_kv_buffer_triton_fp8_quant as set_mla_kv_buffer_triton_fp8_quant,
    set_mla_kv_scale_buffer_kernel as set_mla_kv_scale_buffer_kernel,
    set_mla_kv_scale_buffer_triton as set_mla_kv_scale_buffer_triton,
)# 以下保留的函数与内核无关,仅涉及哈希和策略工厂
from sglang.srt.environ import envs
from sglang.srt.mem_cache.evict_policy import (
    EvictionStrategy, FIFOStrategy, FILOStrategy, LFUStrategy,
    LRUStrategy, MRUStrategy, PriorityStrategy, SLRUStrategy,
)
_EVICTION_POLICY_FACTORIES: dict[str, Callable[[], EvictionStrategy]] = {
    "lru": LRUStrategy, "lfu": LFUStrategy, "fifo": FIFOStrategy,
    "mru": MRUStrategy, "filo": FILOStrategy, "priority": PriorityStrategy,
    "slru": SLRUStrategy,
}

评论区精华

代码组织结构讨论 设计

Reviewer DarkSharpness 评论:'We indeed should do this long before.'(我们早就应该做这件事了。)

结论:同意重构,无进一步修改要求。 · 已解决

风险与影响

主要风险在于遗留导入路径的完整性:尽管 PR 保留了别名导入,但大量文件被修改,可能遗漏某些内部调用或导致循环导入。此外,多次合并 main 可能引入不一致。但 CI 通过且无 review 异议,风险可控。

对用户无功能影响,纯代码重构。对团队维护有利:降低认知负荷,避免后期开发者在不相关文件中添加新内核。构建时间因新增文件略有变化,但可忽略。

向后兼容依赖遗留导入 大规模文件修改可能引入遗漏

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论