Prhub

#22434 [CPU][sgl-kernel] `extend_attention_cpu` and `flash_attn_varlen_func`: fix `nan` for large seq

sgl-project/sglang · 作者 chunyuan-w · 合并时间 2026-04-17 13:01

分析状态 已生成
文件变更 4提交数 10 · 评论 3
代码增减 +99 / -8
sgl-kernel bugfix cpu run-ci consistency

执行摘要

修复 CPU 内核中因果掩码逻辑错误,解决大序列输入时 NaN 问题。

Issue 20051报告了在CPU上使用BF16精度时,当prefill长度超过约4096 tokens时,GDN内核产生NaN值,导致torch.multinomial崩溃。PR body指出在extend_attention_cpu和flash_attn_varlen_func内核中,因果掩码逻辑错误导致大序列输入时出现NaN,具体表现为last_col为负或掩码条件不充分。

该PR值得精读,重点关注因果掩码条件从num_keys - n <= BLOCK_Nn + n_size - 1 > m的设计变更,这揭示了块状注意力中处理未来键的通用模式。工程师应学习如何通过钳位last_col避免越界写入,并在测试中覆盖边界情况。

讨论亮点

reviewer gemini-code-assist[bot] 指出测试辅助函数 _test_extend_attention_fixed_lens 与现有函数 _test_extend_attention_once 存在代码重复,建议重构以提升可维护性。作者在后续提交中通过提取公共逻辑解决了此问题,确保代码更清晰。

实现拆解

  1. 修改因果掩码条件:在sgl-kernel/csrc/cpu/extend.cppsgl-kernel/csrc/cpu/flash_attn.cpp中,将原有的掩码条件num_keys - n <= BLOCK_N更新为n + n_size - 1 > m,以正确识别需要掩码的键块(即最后一个键位置大于首个查询位置的块)。
  2. 钳位last_col到-1:在同一代码段中,添加last_col = std::max(last_col, -1),防止当last_col为负时fill_stub越界写入,确保整个行被掩码。
  3. 扩展测试覆盖:在test/srt/cpu/test_extend.pytest/srt/cpu/test_flash_attn.py中,新增测试用例test_extend_attention_large_seq_causal_masktest_flash_attn_large_seq_causal_mask,使用序列长度5000+验证修复。
  4. 测试辅助函数重构:根据review反馈,通过提交refactor ut to remove duplicated code简化测试代码,减少重复逻辑。
文件 模块 状态 重要度
sgl-kernel/csrc/cpu/extend.cpp 注意力内核 modified 6.44
sgl-kernel/csrc/cpu/flash_attn.cpp 注意力内核 modified 5.91
test/srt/cpu/test_extend.py 测试覆盖 modified 5.77
test/srt/cpu/test_flash_attn.py 测试覆盖 modified 5.7
sgl-kernel/csrc/cpu/extend.cpp core-logic

核心注意力扩展内核,修复因果掩码逻辑错误,防止大序列输入时 NaN 产生和越界写入。

// apply causal mask
// [Note] condition to apply causal mask.
// Mask any block whose last key (n + n_size - 1) is strictly after the first query position (m), i.e. n + n_size - 1 > m.
// 原条件 num_keys - n <= BLOCK_N 仅在最后一个n-block生效,但BLOCK_M=512, BLOCK_N=768时,首个n-block可能包含未来键。
if (n + n_size - 1 > m) {
  for (int row = 0; row < m_size; ++row) {
    int last_col = m + row - n;
    // [Note] mask the entire row if last_col < 0.
    // 当 n > m + row 时,该块所有键都是未来键,需要掩码整个行。
    // 钳位到-1,避免 last_col+1 <= 0 导致 fill_stub 越界写入。
    last_col = std::max(last_col, -1);
    // fill [last_col + 1, n_size) to -inf
    float* row_ptr = s_i + row * BLOCK_N;
    fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
  }
}

关键符号

extend_attention_kernel_impl flash_attn_kernel_impl flash_attn_varlen_kernel_impl

评论区精华

测试辅助函数代码重复 设计

reviewer gemini-code-assist[bot] 指出新 helper 函数 _test_extend_attention_fixed_lens 与现有函数 _test_extend_attention_once 存在大量逻辑重复,建议重构以提升可维护性。

结论:作者通过提交 'refactor ut to remove duplicated code' 重构了测试代码,提取公共逻辑,解决了重复问题。 · 已解决

风险与影响

修复涉及核心注意力计算路径,风险包括:回归风险(如果新掩码条件未覆盖所有边界情况,可能导致掩码不全或过度掩码)、性能影响(额外钳位操作可能轻微增加CPU开销,但可忽略)、兼容性问题(修改可能影响所有CPU上的注意力操作,但通过测试验证了正确性)。

对用户:解决了大序列输入时NaN崩溃问题,提升模型在CPU上的稳定性和长上下文处理能力。对系统:确保注意力计算正确性,避免因NaN传播导致的后续采样错误。对团队:提供了更健壮的内核实现,为后续CPU优化奠定基础。

核心路径变更 边界条件处理 测试覆盖增强

关联 Issue

#20051 [Bug] [CPU] GDN (chunk_gated_delta_rule_cpu) produces NaN with BF16 when prefill exceeds ~4096 tokens

完整报告

执行摘要

本PR修复了CPU上extend_attention_cpuflash_attn_varlen_func内核中因果掩码逻辑错误,解决了输入序列长度超过4096时出现的NaN崩溃问题。通过调整掩码条件和钳位last_col,避免了越界写入和掩码不全,并新增测试用例确保修复有效性。

功能与动机

为什么做:Issue 20051报告了在CPU上使用BF16精度时,当prefill长度超过约4096 tokens时,GDN内核产生NaN值,导致torch.multinomial崩溃。PR body进一步指出,在extend_attention_cpuflash_attn_varlen_func内核中,因果掩码逻辑错误导致大序列输入时出现NaN。具体地,当BLOCK_M=512BLOCK_N=768时,原有掩码条件num_keys - n <= BLOCK_N仅对最后一个键块生效,忽略了首个键块中的未来键,且last_col为负时引发越界写入。

实现拆解

  1. 修改因果掩码条件:在sgl-kernel/csrc/cpu/extend.cppsgl-kernel/csrc/cpu/flash_attn.cpp中,将掩码条件从num_keys - n <= BLOCK_N更新为n + n_size - 1 > m。这个新条件基于块中最后一个键的位置(n + n_size - 1)是否严格大于首个查询位置(m),从而正确识别需要掩码的块。
    cpp // 示例代码片段来自extend.cpp if (n + n_size - 1 > m) { for (int row = 0; row < m_size; ++row) { int last_col = m + row - n; last_col = std::max(last_col, -1); // 钳位到-1,避免负索引 float* row_ptr = s_i + row * BLOCK_N; fill_stub(row_ptr + last_col + 1, -inf, n_size - last_col - 1); } }
  2. 钳位last_col到-1:添加last_col = std::max(last_col, -1),防止当last_col为负时fill_stub写入非法内存地址,确保整个行被掩码为-inf
  3. 扩展测试覆盖:在test/srt/cpu/test_extend.py中,修改_test_extend_attention_once以支持固定长度参数,并新增test_extend_attention_large_seq_causal_mask测试序列长度5000。在test/srt/cpu/test_flash_attn.py中,类似地新增_test_flash_attn_large_seq_causal_mask_oncetest_flash_attn_large_seq_causal_mask,覆盖单序列和多序列场景。
  4. 测试代码重构:根据review反馈,作者通过提交重构测试代码,减少_test_extend_attention_once_test_extend_attention_fixed_lens之间的重复逻辑,提升可维护性。

关键源码片段

sgl-kernel/csrc/cpu/extend.cpp

核心注意力扩展内核,修复因果掩码逻辑错误,防止大序列输入时NaN产生和越界写入。

// apply causal mask
// [Note] condition to apply causal mask.
// Mask any block whose last key (n + n_size - 1) is strictly after the first query position (m), i.e. n + n_size - 1 > m.
// 原条件 num_keys - n <= BLOCK_N 仅在最后一个n-block生效,但BLOCK_M=512, BLOCK_N=768时,首个n-block可能包含未来键。
if (n + n_size - 1 > m) {
  for (int row = 0; row < m_size; ++row) {
    int last_col = m + row - n;
    // [Note] mask the entire row if last_col < 0.
    // 当 n > m + row 时,该块所有键都是未来键,需要掩码整个行。
    // 钳位到-1,避免 last_col+1 <= 0 导致 fill_stub 越界写入。
    last_col = std::max(last_col, -1);
    // fill [last_col + 1, n_size) to -inf
    float* row_ptr = s_i + row * BLOCK_N;
    fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
  }
}

评论区精华

reviewer gemini-code-assist[bot] 指出测试辅助函数存在代码重复问题:

“The new helper function _test_extend_attention_fixed_lens shares a lot of logic with the existing _test_extend_attention_once, leading to significant code duplication. To improve maintainability, I recommend refactoring these two methods.”

作者在后续提交中解决了此问题,通过提取公共逻辑简化了测试代码。

风险与影响

  • 技术风险:修改涉及核心注意力路径,可能引入回归(如掩码过度或不足),但新增测试用例降低了风险;额外钳位操作对性能影响可忽略。
  • 影响范围:对用户,解决了长序列处理中的NaN崩溃,提升CPU上模型的稳定性;对系统,确保注意力计算正确性,避免NaN传播;对团队,提供了更健壮的内核实现,便于未来扩展CPU功能。

关联脉络

  • 关联Issue 20051:直接驱动了本PR,报告了GDN内核的NaN问题,本PR修复了更通用的注意力内核问题。
  • 历史PR关联:与PR 22842(新增CPU内核)和PR 22990(修复调度逻辑)类似,都涉及核心模块的bugfix和优化,反映仓库在CPU支持和系统稳定性上的持续演进。近期PR如22406(优化KV缓存)也展示了内核层的性能改进趋势。

参与讨论