Prhub

#40772 [Bugfix] Fix IMA in DSA + MTP

原始 PR 作者 WoosukKwon 合并时间 2026-04-24 16:40 文件变更 1 提交数 1 评论 0 代码增减 +14 / -7

执行摘要

修复 DSA+MTP 场景下的 IMA bug

PR #40654 为了避免GPU→CPU同步,引入了seq_lens_cpu_upper_bound,使得kernel中的num_tokens可能大于实际有效token数,导致在DSA+MTP场景下访问无效内存地址,触发IMA错误。

建议精读:该PR展示了如何为性能优化(避免GPU→CPU同步)引入的副作用打补丁,值得关注边界情况处理。

讨论亮点

无讨论。

实现拆解

  1. csrc/cache_kernels.cucp_gather_indexer_k_quant_cache_kernel 中,初始化共享内存数组 batch_idx 为 -1,确保未被赋值的线程不会使用无效索引。
  2. 在循环计算批次索引后,添加 __syncthreads() 保证所有线程的 batch_idx 写入完成。
  3. 将原有的条件 __syncwarp() (仅在非ROCm时生效) 替换为 __syncthreads(),确保所有线程同步。
  4. 增加对 batch < 0 的检查,如果批次索引无效则直接返回,避免访问越界。
文件 模块 状态 重要度
csrc/cache_kernels.cu 内核 modified 4.86

关键符号

cp_gather_indexer_k_quant_cache_kernel

关键源码片段

csrc/cache_kernels.cu core-logic

核心内核函数 `cp_gather_indexer_k_quant_cache_kernel` 的修复,解决了 DSA+MTP 场景下因 num_tokens 上界导致的越界访问。

__global__ void cp_gather_indexer_k_quant_cache_kernel(...) {
    // ...
    __shared__ int batch_idx[BLOCK_Y_SIZE];
    if (threadIdx.x == 0) {
        batch_idx[threadIdx.y] = -1; // 初始化为无效值,防止未更新时使用
    }
    __syncthreads();    for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x)); iter++) {
        int tid = iter * blockDim.x + threadIdx.x;
        if (tid < batch_size) {
            // 某个线程负责写入 batch_idx
            batch_idx[threadIdx.y] = /* 计算 */ ;
        }
    }
    __syncthreads(); // 确保所有线程的 batch_idx 已更新    // num_tokens 可能为分配上界,需校验 batch 有效性
    const int batch = batch_idx[threadIdx.y];
    if (head_idx >= head_dim || token_idx >= num_tokens || batch < 0) {
        return; // batch<0 表示该线程负责的批次索引未初始化,跳过
    }
    // 使用安全的 batch 访问后续数据
    const int inbatch_seq_idx = token_idx - cu_seq_lens[batch];
    // ...
}

说明: 该kernel原先假设所有线程都能找到有效的batch索引,但num_tokens上界可能导致部分线程访问无效batch。通过初始化共享数组为-1,并在访问前检查batch<0,优雅地跳过无效线程。

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

该修改针对特定kernel中的同步机制和边界检查,风险较低。但替换 __syncwarp()__syncthreads() 可能引入性能微损,需确认不影响warp内同步的优化场景。

影响仅限启用DSA+MTP的用户,修复可能导致此类用户之前遇到的IMA崩溃。对其他场景无影响。

核心 kernel 修改 同步语义变更 (__syncwarp->__syncthreads)

关联 Issue

#40654 [Core] Avoid seq_lens_cpu GPU->CPU sync

完整报告

参与讨论