执行摘要
- 一句话:融合KDA的gate+cumsum操作并重用chunk index,提升内核性能2.2-2.65倍和端到端吞吐量6-11%。
- 推荐动作:该PR值得精读,特别是
kda_gate_chunk_cumsum内核的实现展示了如何通过融合计算减少内存往返,以及chunk index重用优化避免了重复预处理。关注设计权衡(如简化路径、测试覆盖)和性能提升技巧。
功能与动机
优化KDA内核性能,减少内存流量和重复计算。PR body中说明:'Optimize KDA kernel with fusing gate+cumsum and reusing chunk index.',旨在通过融合操作消除一个内核启动和一次全局内存往返(HBM流量从136MB降至68MB),并通过重用chunk index避免在varlen模式下5-6次冗余的prepare_chunk_indices调用,从而提升整体效率。
实现拆解
- 新增融合内核kda_gate_chunk_cumsum:在
python/sglang/srt/layers/attention/fla/kda.py中新增Triton内核函数kda_gate_chunk_cumsum和kda_gate_chunk_cumsum_vector_kernel,整合gate激活(公式-exp(A_log) * softplus(raw_g + dt_bias))和chunk-local cumulative sum,减少内存往返。内核使用autotune配置(BS_LIST基于共享内存检测优化)。
- 修改KDA主函数以使用融合内核:在
kda.py的chunk_kda_fwd函数中,添加A_log、dt_bias等参数,根据是否提供A_log选择调用kda_gate_chunk_cumsum(融合路径)或chunk_local_cumsum(传统路径),并预计算chunk_indices传递给下游函数以避免重复计算。
- 更新模型层适配:修改
python/sglang/srt/models/kimi_linear.py,移除对fused_kda_gate的导入和调用,在prefill模式中传递raw gate(未激活)给chunk_kda_fwd,让融合内核处理gate激活。
- 添加基准测试和单元测试:新增
benchmark/bench_linear_attention/bench_fused_gate_cumsum.py,包含make_inputs、run_ref、run_fused等函数,用于性能对比和验证;扩展test/registered/attention/test_kda_kernels.py,新增TestKDAGateChunkCumsum测试类,覆盖varlen、bias等多种场景,确保正确性(max_diff<2e-4)。
- 其他相关文件调整:更新
cumsum.py、chunk_delta_h.py、chunk_intra.py等文件,添加chunk_indices参数并优化逻辑,确保整个调用链支持index重用。
关键文件:
benchmark/bench_linear_attention/bench_fused_gate_cumsum.py(模块 基准测试;类别 source;类型 benchmark;符号 make_inputs, run_ref, run_fused, verify_correctness): 新增的基准测试文件,用于对比融合与分离路径的性能,验证优化效果,包含关键函数如run_ref和run_fused。
python/sglang/srt/layers/attention/fla/kda.py(模块 注意力内核;类别 source;类型 core-logic;符号 softplus_fwd, kda_gate_chunk_cumsum_vector_kernel, kda_gate_chunk_cumsum, chunk_kda_fwd): 核心实现文件,新增融合内核kda_gate_chunk_cumsum并修改chunk_kda_fwd以支持融合和chunk index重用,影响KDA主路径。
test/registered/attention/test_kda_kernels.py(模块 单元测试;类别 test;类型 test-coverage;符号 TestKDAGateChunkCumsum, _ref_gate_cumsum, _run_case, test_varlen_with_bias): 扩展的单元测试文件,新增TestKDAGateChunkCumsum类验证融合内核的正确性,覆盖多种场景(varlen、bias等)。
关键符号:kda_gate_chunk_cumsum, chunk_kda_fwd, softplus_fwd, chunk_local_cumsum, prepare_chunk_indices
关键源码片段
benchmark/bench_linear_attention/bench_fused_gate_cumsum.py
新增的基准测试文件,用于对比融合与分离路径的性能,验证优化效果,包含关键函数如run_ref和run_fused。
def run_ref(inp):
"""Separate path: torch gate activation -> chunk_local_cumsum."""
raw_g = inp["raw_g"] # [1, T, H, K]
A_log = inp["A_log"] # [H]
dt_bias = inp["dt_bias"] # [H*K]
cu_seqlens = inp["cu_seqlens"]
H, K = inp["H"], inp["K"]
# Step 1: gate activation using torch ops
g_float = raw_g.float()
if dt_bias is not None:
g_float = g_float + dt_bias.float().view(1, 1, H, K) # 添加 bias
g_activated = -torch.exp(
A_log.float().view(1, 1, H, 1)
) * torch.nn.functional.softplus(g_float) # 计算激活后的 gate
# Step 2: chunk-local cumsum
chunk_indices = prepare_chunk_indices(cu_seqlens, CHUNK_SIZE)
g_cumsum = chunk_local_cumsum(
g_activated,
chunk_size=CHUNK_SIZE,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
) # 执行 cumsum
return g_cumsum
def run_fused(inp):
"""Fused path: kda_gate_chunk_cumsum (single kernel)."""
raw_g = inp["raw_g"]
A_log = inp["A_log"]
dt_bias = inp["dt_bias"]
cu_seqlens = inp["cu_seqlens"]
chunk_indices = prepare_chunk_indices(cu_seqlens, CHUNK_SIZE)
g_cumsum = kda_gate_chunk_cumsum(
raw_g,
A_log=A_log,
chunk_size=CHUNK_SIZE,
dt_bias=dt_bias,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
) # 调用融合内核,一次性完成 gate 激活和 cumsum
return g_cumsum
python/sglang/srt/layers/attention/fla/kda.py
核心实现文件,新增融合内核kda_gate_chunk_cumsum并修改chunk_kda_fwd以支持融合和chunk index重用,影响KDA主路径。
def chunk_kda_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
initial_state_indices: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
A_log: Optional[torch.Tensor] = None, # 新增参数:gate 激活的 log-scale
dt_bias: Optional[torch.Tensor] = None, # 新增参数:gate 的 bias
lower_bound: Optional[float] = None,
):
chunk_size = 64
# 预计算 chunk_indices,避免下游函数重复计算
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
)
if A_log is not None:
# 融合路径:gate 激活 + cumsum
g = kda_gate_chunk_cumsum(
g,
A_log=A_log,
dt_bias=dt_bias,
lower_bound=lower_bound,
chunk_size=chunk_size,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices, # 传递预计算的 index
)
else:
# 传统路径:g 已由调用者激活,仅执行 cumsum
g = chunk_local_cumsum(
g,
chunk_size=chunk_size,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices, # 传递预计算的 index
)
# 后续 KDA 计算逻辑保持不变
...
评论区精华
- 设计简化:reviewer kaixih指出
chunk_kda_fwd中三条路径过于复杂,建议简化为基于A_log的条件分支。作者采纳,修改后逻辑更清晰。
- 测试覆盖:kaixih询问新内核的单元测试,作者回应'Added unit test.',在测试文件中新增了
TestKDAGateChunkCumsum类进行验证。
- 优化解释:reviewer rainj-me质疑预计算
chunk_indices是否冗余,作者解释这是有意优化,避免下游函数(如kda_gate_chunk_cumsum、chunk_local_cumsum)重复调用prepare_chunk_indices,从而减少开销。
- 代码清理:kaixih建议移除不再使用的
fused_kda_gate函数,作者回应'Removed.',并从kimi_linear.py中删除相关导入和调用。
- 简化chunk_kda_fwd中的逻辑路径 (design): 作者采纳建议,修改为if-else结构,使代码更清晰。
- 为新融合内核添加单元测试 (testing): 作者回应'Added unit test.',在测试文件中新增了TestKDAGateChunkCumsum类进行验证。
- chunk index重用优化的必要性 (performance): 作者解释这是有意优化,避免下游函数(如kda_gate_chunk_cumsum、chunk_local_cumsum)重复调用prepare_chunk_indices,减少开销。
风险与影响
- 风险:
- 正确性风险:融合内核可能引入数值误差,但测试显示max_diff=1.53e-04、rel_diff=4.13e-06,在可接受范围内;需确保所有边界情况(如varlen、无bias)被测试覆盖。
- 性能回归风险:新内核的autotune配置(BS_LIST)依赖共享内存检测,在AMD等平台需验证适应性;从提交看已调整以支持多平台。
- 兼容性风险:多个函数(如
recompute_w_u_fwd、chunk_gla_fwd_o_gk)添加了chunk_indices参数,需确保所有调用点更新,相关文件(如chunk_delta_h.py)已适配。
- 维护风险:新增内核和逻辑可能增加代码复杂度,但review中已简化设计并移除旧函数。
- 影响:
- 用户影响:端到端KDA预填充吞吐量提升6-11%,尤其在长序列(T_total≥4096)和大批次下收益显著,峰值吞吐量从8637 tok/ms增至9520 tok/ms,改善推理延迟。
- 系统影响:内核级内存流量减半,减少GPU内存带宽压力;chunk index重用降低CPU-GPU传输和计算开销,提升整体系统效率。
- 团队影响:代码更简洁(移除
fused_kda_gate),但需维护新内核;设计模式(融合操作、index重用)可为其他优化提供参考。
- 风险标记:核心路径变更, 数值精度风险, 接口变更
关联脉络
- PR #22544 [Score API] Add Multi-Item Scoring with pre-computed delimiter indices: 类似地采用了预计算索引优化来消除GPU扫描,提升性能,与本PR的chunk index重用策略有共通之处。
- PR #23315 Opt-in strip of thinking tokens from radix cache: 同样涉及缓存和性能优化,展示了仓库中对内存和计算效率的持续改进趋势。
参与讨论