Prhub

#42153 [Perf] Use 2D-grid to eliminate divmod in W8W8 group quant

原始 PR 作者 jiahanc 合并时间 2026-05-12 22:01 文件变更 1 提交数 3 评论 4 代码增减 +69 / -40

执行摘要

用 2D 网格消除 W8A8 分组量化中的 divmod 计算

当前 1D 网格启动的内核使用 global_group_id % padded_groups_per_rowglobal_group_id / padded_groups_per_row 计算分组索引,这些除法/取模是运行时开销。由于每个线程块的图块大小在编译时已知,改用 2D 网格可将索引计算转换为简单的乘法/加法,编译器进一步优化为移位/位运算,从而减少指令数,带来吞吐量提升(见 PR body 性能对比表)。

值得精读:展示了 CUDA 内核优化中利用网格和模板常量消除运行时除法的典型手法。对理解 GPU 性能优化有参考价值,尤其是 2D grid 的应用和编译时常量的使用。

讨论亮点

gemini-code-assist 指出 LAUNCH_REG_KERNEL 宏假设 kx 只能是 16/8/4,建议增加静态断言确保 kx * ry == 16。作者未回复,但 PR 已被合并,表明当前调用面满足假设。该讨论提示在模板分支中应增加鲁棒性检查。

实现拆解

  1. 新增 GetGroupsPerBlockX 辅助函数(csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu),根据 padded_groups_per_row 选择 X 向图块大小(16/8/4),保证 kx * ry == 16
  2. 修改内核模板 per_token_group_quant_8bit_packed_register_kernel,增加编译时常量 kGroupsPerBlockXkRowsPerBlock,移除运行时参数 groups_per_block
  3. 内核体内将一维 local_group_id 拆解为 sf_k_local = local_group_id % kGroupsPerBlockXrow_local = local_group_id / kGroupsPerBlockX,全局索引改为 blockIdx.x * kGroupsPerBlockX + sf_k_localblockIdx.y * kRowsPerBlock + row_local,避免对 padded_groups_per_row 的除法和取模。
  4. per_token_group_quant_8bit_packed 启动函数根据 GetGroupsPerBlockX 返回值分支实例化模板,设置 2D grid 尺寸(padded_groups_per_row / kGroupsPerBlockXtma_aligned_mn / kRowsPerBlock),线程数固定为 16 * THREADS_PER_GROUP = 128
文件 模块 状态 重要度
csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu 量化内核 modified 5.62

关键符号

GetGroupsPerBlockX per_token_group_quant_8bit_packed_register_kernel per_token_group_quant_8bit_packed

关键源码片段

csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu core-logic

所有变更都在这个文件中:新增辅助函数 GetGroupsPerBlockX,修改内核模板增加编译时常量,替换索引计算逻辑,调整启动配置。

// 根据 padded_groups_per_row 选择一个 <= 16 的最大因子,
// 使得 kx * ry 恒为 16(每组可容纳 2 图块,每个图块 8 线程 × 2 组)。
// 因为 padded_groups_per_row 总是 4 的倍数,结果只会在 16/8/4 中选择。
inline int GetGroupsPerBlockX(int64_t padded_groups_per_row) {
  if (padded_groups_per_row % 16 == 0) {
    return 16; // 每行可分配 16 组 X 向图块
  }
  if (padded_groups_per_row % 8 == 0) {
    return 8; // 8 组 X 向图块,Y 向自动取 2(16 / 8 == 2)
  }
  return 4; // 4 组 X 向图块,Y 向自动取 4(16 / 4 == 4)
}
template <typename T, typename DST_DTYPE, int GROUP_SIZE,
          int kGroupsPerBlockX, int kRowsPerBlock>
__global__ void per_token_group_quant_8bit_packed_register_kernel(
    const T* __restrict__ input, void* __restrict__ output_q,
    unsigned int* __restrict__ output_s_packed,
    const int padded_groups_per_row, const int groups_per_row,
    const int mn, const int output_q_mn_extent,
    const int tma_aligned_mn, const int64_t num_scale_elems,
    const float eps, const float min_8bit, const float max_8bit) {
  constexpr int THREADS_PER_GROUP = 8;
  constexpr int VEC_SIZE = 32 / sizeof(T); // 对于 bf16/fp16 = 16  const int local_group_id = threadIdx.x / THREADS_PER_GROUP;
  const int lane_id = threadIdx.x % THREADS_PER_GROUP;  // 将线程块内的一维 local_group_id 拆解为 X 和 Y 两个维度
  const int sf_k_local = local_group_id % kGroupsPerBlockX;
  const int row_local = local_group_id / kGroupsPerBlockX;
  // 全局索引 = 块起始 + 块内偏移,无需运行时除法和取模
  const int sf_k_idx = blockIdx.x * kGroupsPerBlockX + sf_k_local;
  const int mn_idx = blockIdx.y * kRowsPerBlock + row_local;  if (mn_idx >= tma_aligned_mn) {
    return;
  }
  // ... 后续量化计算不变
}

评论区精华

LAUNCH_REG_KERNEL 宏的假设鲁棒性 设计

gemini-code-assist 指出 LAUNCH_REG_KERNEL 宏假设 kx 只能是 16/8/4,建议增加静态断言确保 kx * ry == 16。

结论:作者未回复,PR 被合并,假设在当前调用下成立,但长期应增加显式检查。 · 已解决

风险与影响

风险较低:变更局限于单个 CUDA 文件,算法逻辑不变,仅改变索引计算方式。但 GetGroupsPerBlockX 后备逻辑假设 padded_groups_per_row 为 4 的倍数,若未来某形状不满足该条件,会静默使用 kx=4, ry=4 配置,可能带来意外的性能回退(但不会崩溃)。内核缺少显式静态断言来验证 kx * ry == 16

仅影响 per_token_group_quant_8bit_packed 内核路径,用户无 API 变更。推理吞吐量在特定 shape 下提升 1-10%,对大 batch 推理收益更明显(DeepSeek V4 等模型)。团队可通过微基准复现优化效果。

假设 padded_groups_per_row 为 4 的倍数 缺少异常处理路径

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论