执行摘要
修复 SM90 flash_mask_attn 算子 batch_size 推导错误,放宽 shape 校验以兼容预分配输入。
在SM90 flash mask attention算子中,cu_seqlens_q和seq_lens_encoder的输入shape可能按max_batch维度预分配,其实际有效长度可能小于tensor的第一维大小。此时若以cu_seq_q.dims()[0] - 1推导batch_size,会得到一个偏大的值(等于max_batch而非真实batch size),导致后续kernel launch的batch维度不正确。cu_seq_k始终按真实batch size填充,因此需要确保batch_size推导正确。同时,原有的PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0])等断言在预分配场景下会误报失败,需要放宽校验。
该PR值得精读,重点关注:1. Python侧切片方案的设计权衡,以及是否应将修复逻辑移至CUDA侧。2. shape校验放宽的边界条件处理,是否应添加下界校验以避免越界风险。3. 预分配场景下的测试覆盖缺失问题。
讨论主要集中在两个核心点:1. Copilot建议将移除的严格shape校验放宽为下界校验(如seq_len_encoder.dims()[0] >= batch_size),以避免kernel访问越界导致未定义行为。2. fastdeploy-bot指出PR描述与实际变更存在差异:PR描述说将batch_size推导从cu_seq_q改为cu_seq_k,但实际是通过Python侧切片实现的,建议更新描述或将修复逻辑移至CUDA代码以提高可维护性。
参与讨论