执行摘要
修复 DSA+MTP 场景下的 IMA bug
PR #40654 为了避免GPU→CPU同步,引入了seq_lens_cpu_upper_bound,使得kernel中的num_tokens可能大于实际有效token数,导致在DSA+MTP场景下访问无效内存地址,触发IMA错误。
建议精读:该PR展示了如何为性能优化(避免GPU→CPU同步)引入的副作用打补丁,值得关注边界情况处理。
无讨论。
PR #40654 为了避免GPU→CPU同步,引入了seq_lens_cpu_upper_bound,使得kernel中的num_tokens可能大于实际有效token数,导致在DSA+MTP场景下访问无效内存地址,触发IMA错误。
建议精读:该PR展示了如何为性能优化(避免GPU→CPU同步)引入的副作用打补丁,值得关注边界情况处理。
无讨论。
csrc/cache_kernels.cu 的 cp_gather_indexer_k_quant_cache_kernel 中,初始化共享内存数组 batch_idx 为 -1,确保未被赋值的线程不会使用无效索引。__syncthreads() 保证所有线程的 batch_idx 写入完成。__syncwarp() (仅在非ROCm时生效) 替换为 __syncthreads(),确保所有线程同步。batch < 0 的检查,如果批次索引无效则直接返回,避免访问越界。| 文件 | 模块 | 状态 | 重要度 |
|---|---|---|---|
csrc/cache_kernels.cu |
内核 | modified | 4.86 |
csrc/cache_kernels.cu
core-logic
核心内核函数 `cp_gather_indexer_k_quant_cache_kernel` 的修复,解决了 DSA+MTP 场景下因 num_tokens 上界导致的越界访问。
__global__ void cp_gather_indexer_k_quant_cache_kernel(...) {
// ...
__shared__ int batch_idx[BLOCK_Y_SIZE];
if (threadIdx.x == 0) {
batch_idx[threadIdx.y] = -1; // 初始化为无效值,防止未更新时使用
}
__syncthreads();
for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x)); iter++) {
int tid = iter * blockDim.x + threadIdx.x;
if (tid < batch_size) {
// 某个线程负责写入 batch_idx
batch_idx[threadIdx.y] = /* 计算 */ ;
}
}
__syncthreads(); // 确保所有线程的 batch_idx 已更新
// num_tokens 可能为分配上界,需校验 batch 有效性
const int batch = batch_idx[threadIdx.y];
if (head_idx >= head_dim || token_idx >= num_tokens || batch < 0) {
return; // batch<0 表示该线程负责的批次索引未初始化,跳过
}
// 使用安全的 batch 访问后续数据
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch];
// ...
}
说明: 该kernel原先假设所有线程都能找到有效的batch索引,但num_tokens上界可能导致部分线程访问无效batch。通过初始化共享数组为-1,并在访问前检查batch<0,优雅地跳过无效线程。
当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。
该修改针对特定kernel中的同步机制和边界检查,风险较低。但替换 __syncwarp() 为 __syncthreads() 可能引入性能微损,需确认不影响warp内同步的优化场景。
影响仅限启用DSA+MTP的用户,修复可能导致此类用户之前遇到的IMA崩溃。对其他场景无影响。
参与讨论