Prhub

#22434 [CPU][sgl-kernel] `extend_attention_cpu` and `flash_attn_varlen_func`: fix `nan` for large seq

原始 PR 作者 chunyuan-w 合并时间 2026-04-17 13:01 文件变更 4 提交数 10 评论 3 代码增减 +99 / -8

执行摘要

修复 CPU 内核中因果掩码逻辑错误,解决大序列输入时 NaN 问题。

Issue 20051报告了在CPU上使用BF16精度时,当prefill长度超过约4096 tokens时,GDN内核产生NaN值,导致torch.multinomial崩溃。PR body指出在extend_attention_cpu和flash_attn_varlen_func内核中,因果掩码逻辑错误导致大序列输入时出现NaN,具体表现为last_col为负或掩码条件不充分。

该PR值得精读,重点关注因果掩码条件从num_keys - n <= BLOCK_Nn + n_size - 1 > m的设计变更,这揭示了块状注意力中处理未来键的通用模式。工程师应学习如何通过钳位last_col避免越界写入,并在测试中覆盖边界情况。

讨论亮点

reviewer gemini-code-assist[bot] 指出测试辅助函数 _test_extend_attention_fixed_lens 与现有函数 _test_extend_attention_once 存在代码重复,建议重构以提升可维护性。作者在后续提交中通过提取公共逻辑解决了此问题,确保代码更清晰。

实现拆解

  1. 修改因果掩码条件:在sgl-kernel/csrc/cpu/extend.cppsgl-kernel/csrc/cpu/flash_attn.cpp中,将原有的掩码条件num_keys - n <= BLOCK_N更新为n + n_size - 1 > m,以正确识别需要掩码的键块(即最后一个键位置大于首个查询位置的块)。
  2. 钳位last_col到-1:在同一代码段中,添加last_col = std::max(last_col, -1),防止当last_col为负时fill_stub越界写入,确保整个行被掩码。
  3. 扩展测试覆盖:在test/srt/cpu/test_extend.pytest/srt/cpu/test_flash_attn.py中,新增测试用例test_extend_attention_large_seq_causal_masktest_flash_attn_large_seq_causal_mask,使用序列长度5000+验证修复。
  4. 测试辅助函数重构:根据review反馈,通过提交refactor ut to remove duplicated code简化测试代码,减少重复逻辑。
文件 模块 状态 重要度
sgl-kernel/csrc/cpu/extend.cpp 注意力内核 modified 6.44
sgl-kernel/csrc/cpu/flash_attn.cpp 注意力内核 modified 5.91
test/srt/cpu/test_extend.py 测试覆盖 modified 5.77
test/srt/cpu/test_flash_attn.py 测试覆盖 modified 5.7

关键符号

extend_attention_kernel_impl flash_attn_kernel_impl flash_attn_varlen_kernel_impl

关键源码片段

sgl-kernel/csrc/cpu/extend.cpp core-logic

核心注意力扩展内核,修复因果掩码逻辑错误,防止大序列输入时 NaN 产生和越界写入。

// apply causal mask
// [Note] condition to apply causal mask.
// Mask any block whose last key (n + n_size - 1) is strictly after the first query position (m), i.e. n + n_size - 1 > m.
// 原条件 num_keys - n <= BLOCK_N 仅在最后一个 n-block 生效,但 BLOCK_M=512, BLOCK_N=768 时,首个 n-block 可能包含未来键。
if (n + n_size - 1 > m) {
  for (int row = 0; row < m_size; ++row) {
    int last_col = m + row - n;
    // [Note] mask the entire row if last_col < 0.
    // 当 n > m + row 时,该块所有键都是未来键,需要掩码整个行。
    // 钳位到 -1,避免 last_col+1 <= 0 导致 fill_stub 越界写入。
    last_col = std::max(last_col, -1);
    // fill [last_col + 1, n_size) to -inf
    float* row_ptr = s_i + row * BLOCK_N;
    fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
  }
}

评论区精华

测试辅助函数代码重复 设计

reviewer gemini-code-assist[bot] 指出新 helper 函数 _test_extend_attention_fixed_lens 与现有函数 _test_extend_attention_once 存在大量逻辑重复,建议重构以提升可维护性。

结论:作者通过提交 'refactor ut to remove duplicated code' 重构了测试代码,提取公共逻辑,解决了重复问题。 · 已解决

风险与影响

修复涉及核心注意力计算路径,风险包括:回归风险(如果新掩码条件未覆盖所有边界情况,可能导致掩码不全或过度掩码)、性能影响(额外钳位操作可能轻微增加CPU开销,但可忽略)、兼容性问题(修改可能影响所有CPU上的注意力操作,但通过测试验证了正确性)。

对用户:解决了大序列输入时NaN崩溃问题,提升模型在CPU上的稳定性和长上下文处理能力。对系统:确保注意力计算正确性,避免因NaN传播导致的后续采样错误。对团队:提供了更健壮的内核实现,为后续CPU优化奠定基础。

核心路径变更 边界条件处理 测试覆盖增强

关联 Issue

#20051 [Bug] [CPU] GDN (chunk_gated_delta_rule_cpu) produces NaN with BF16 when prefill exceeds ~4096 tokens

完整报告

参与讨论