执行摘要
- 一句话:集中分散的 Triton 工具内核到 triton_ops 子包
- 推荐动作:值得精读,了解 sglang 中 Triton 内核的组织方式。对于新贡献者,了解 triton_ops 布局有助于快速定位内核。
功能与动机
PR body 说明:若干非 diffusion Triton 内核嵌入在大型后端/辅助文件中,此 PR 将这些内核集中到领域特定的 triton_ops 模块,作为纯导入重构,不改内核行为。
实现拆解
- 在 python/sglang/srt/layers/attention/triton_ops/ 下新建 cache_ops.py、kv_indices.py、pad.py、rope_cache.py、metadata.py,将原在 attention/utils.py 和 flashattention_backend.py 中的 Triton 内核按功能搬入。
- 在 python/sglang/srt/mem_cache/triton_ops/ 下新建 mla_buffer.py、common.py,搬入 MLA KV buffer 操作和内存池相关内核。
- 在 python/sglang/srt/speculative/triton_ops/ 下新建 cache_locs.py、multi_layer_eagle.py,搬入推测解码缓存位置和多层 Eagle 内核。
- 原文件改为从新模块重新导入所有公开符号(as 别名保持命名一致),确保所有已有调用方无需修改。
- 调整导入关系,删除原文件中 triton、torch 导入(不再需要),并删除内联内核定义。
- 合并 main 分支时解决冲突,确保导入顺序正确。
关键文件:
python/sglang/srt/layers/attention/utils.py(模块 注意力层;类别 source;类型 core-logic;符号 create_flashinfer_kv_indices_triton, get_num_page_per_block_flashmla, create_flashmla_kv_indices_triton, concat_and_cast_mha_k_kernel): 曾经内联定义 8 个 Triton 内核;修改后删除了所有内核定义,改为从 triton_ops 导入,是此次重构的核心示范文件。
python/sglang/srt/speculative/spec_utils.py(模块 推测解码;类别 source;类型 core-logic;符号 create_extend_after_decode_spec_info, assign_req_to_token_pool, assign_req_to_token_pool_func, assign_draft_cache_locs): 删除了 5 个内联 Triton 内核定义,改为从 speculative/triton_ops/cache_locs.py 导入,是推测解码模块的核心文件。
python/sglang/srt/layers/attention/flashattention_backend.py(模块 注意力层;类别 source;类型 core-logic;符号 _prepare_swa_spec_page_table_kernel, prepare_swa_spec_page_table_triton, _fused_metadata_kernel_general, _fused_metadata_kernel_ps1_no_swa): 删除了 _prepare_swa_spec_page_table_kernel 等 4 个内核定义,改为从 triton_ops.metadata 导入,是 FlashAttention 后端的关键修改。
python/sglang/srt/mem_cache/utils.py(模块 内存缓存;类别 source;类型 core-logic;符号 set_mla_kv_buffer_kernel, set_mla_kv_buffer_triton, set_mla_kv_buffer_fp8_quant_kernel, set_mla_kv_buffer_triton_fp8_quant): 删除了 set_mla_kv_buffer_kernel 等 8 个内核定义,改为从 triton_ops.mla_buffer 导入,是内存缓存层的核心重构。
python/sglang/srt/speculative/multi_layer_eagle_utils.py(模块 推测解码;类别 source;类型 core-logic;符号 rotate_input_ids_kernel, rotate_input_ids_triton, assign_new_state_kernel, assign_new_state_triton): 删除了 rotate_input_ids_kernel 等 7 个内核,改为从 triton_ops.multi_layer_eagle 导入,是多层 Eagle 推测解码的核心文件。
关键符号:create_flashinfer_kv_indices_triton, create_flashmla_kv_indices_triton, concat_and_cast_mha_k_triton, pad_sequence_with_mask, seqlens_expand_triton, set_mla_kv_buffer_triton, set_mla_kv_buffer_triton_fp8_quant, get_mla_kv_buffer_triton, normal_decode_set_metadata, prepare_swa_spec_page_table_triton, create_extend_after_decode_spec_info, assign_req_to_token_pool, assign_draft_cache_locs, generate_draft_decode_kv_indices, rotate_input_ids_triton, assign_new_state_triton, assign_hidden_states_pool_triton, write_req_to_token_pool_triton, get_last_loc_triton, canonicalize_stride
关键源码片段
python/sglang/srt/layers/attention/utils.py
曾经内联定义 8 个 Triton 内核;修改后删除了所有内核定义,改为从 triton_ops 导入,是此次重构的核心示范文件。
# 文件 python/sglang/srt/layers/attention/utils.py ( 重构后版本 )
# 原内联的 Triton 内核已迁移至 triton_ops 子模块,此处仅导入并保持别名
from sglang.jit_kernel.utils import is_arch_support_pdl
from sglang.srt.layers.attention.triton_ops.cache_ops import (
concat_and_cast_mha_k_kernel as concat_and_cast_mha_k_kernel,
concat_and_cast_mha_k_triton as concat_and_cast_mha_k_triton,
launch_reshape_and_cache_flash as launch_reshape_and_cache_flash,
reshape_and_cache_flash as reshape_and_cache_flash,
)
from sglang.srt.layers.attention.triton_ops.kv_indices import (
create_flashinfer_kv_indices_triton as create_flashinfer_kv_indices_triton,
create_flashmla_kv_indices_triton as create_flashmla_kv_indices_triton,
get_num_page_per_block_flashmla as get_num_page_per_block_flashmla,
)
from sglang.srt.layers.attention.triton_ops.pad import (
pad_sequence_with_mask as pad_sequence_with_mask,
pad_sequence_with_mask_kernel as pad_sequence_with_mask_kernel,
seqlens_expand_kernel as seqlens_expand_kernel,
seqlens_expand_triton as seqlens_expand_triton,
)
from sglang.srt.layers.attention.triton_ops.rope_cache import (
fused_qk_rope_reshape_and_cache as fused_qk_rope_reshape_and_cache,
)
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sglang.jit_kernel.concat_mla import concat_mla_absorb_q
# 新增函数:处理因 num_kv_heads=1 产生的退化 stride,
# 避免 flashinfer TMA 描述校验失败(详见 issue #2232)
def canonicalize_stride(tensor: torch.Tensor) -> torch.Tensor:
sizes = tensor.size()
strides = tensor.stride()
ndim = tensor.dim()
need_fix = any(
sizes[i] == 1 and strides[i] == strides[i + 1] for i in range(ndim - 1)
)
if not need_fix:
return tensor
new_strides = [0] * ndim
new_strides[-1] = 1
for i in range(ndim - 2, -1, -1):
new_strides[i] = max(new_strides[i + 1] * sizes[i + 1], strides[i])
return tensor.as_strided(sizes, new_strides)
python/sglang/srt/mem_cache/utils.py
删除了 set_mla_kv_buffer_kernel 等 8 个内核定义,改为从 triton_ops.mla_buffer 导入,是内存缓存层的核心重构。
# 文件 python/sglang/srt/mem_cache/utils.py ( 重构后头部 )
# 原内联的 MLA KV Buffer Triton 内核已迁移至 triton_ops.mla_buffer,此处仅导入
from sglang.srt.mem_cache.triton_ops.mla_buffer import (
get_mla_kv_buffer_kernel as get_mla_kv_buffer_kernel,
get_mla_kv_buffer_triton as get_mla_kv_buffer_triton,
set_mla_kv_buffer_fp8_quant_kernel as set_mla_kv_buffer_fp8_quant_kernel,
set_mla_kv_buffer_kernel as set_mla_kv_buffer_kernel,
set_mla_kv_buffer_triton as set_mla_kv_buffer_triton,
set_mla_kv_buffer_triton_fp8_quant as set_mla_kv_buffer_triton_fp8_quant,
set_mla_kv_scale_buffer_kernel as set_mla_kv_scale_buffer_kernel,
set_mla_kv_scale_buffer_triton as set_mla_kv_scale_buffer_triton,
)
# 以下保留的函数与内核无关,仅涉及哈希和策略工厂
from sglang.srt.environ import envs
from sglang.srt.mem_cache.evict_policy import (
EvictionStrategy, FIFOStrategy, FILOStrategy, LFUStrategy,
LRUStrategy, MRUStrategy, PriorityStrategy, SLRUStrategy,
)
_EVICTION_POLICY_FACTORIES: dict[str, Callable[[], EvictionStrategy]] = {
"lru": LRUStrategy, "lfu": LFUStrategy, "fifo": FIFOStrategy,
"mru": MRUStrategy, "filo": FILOStrategy, "priority": PriorityStrategy,
"slru": SLRUStrategy,
}
评论区精华
仅有一条审核评论,来自 DarkSharpness:'We indeed should do this long before.'(我们早就应该做这件事了。)表明社区普遍认同这一重构方向。
- 代码组织结构讨论 (design): 同意重构,无进一步修改要求。
风险与影响
- 风险:主要风险在于遗留导入路径的完整性:尽管 PR 保留了别名导入,但大量文件被修改,可能遗漏某些内部调用或导致循环导入。此外,多次合并 main 可能引入不一致。但 CI 通过且无 review 异议,风险可控。
- 影响:对用户无功能影响,纯代码重构。对团队维护有利:降低认知负荷,避免后期开发者在不相关文件中添加新内核。构建时间因新增文件略有变化,但可忽略。
- 风险标记:向后兼容依赖遗留导入, 大规模文件修改可能引入遗漏
关联脉络
参与讨论