Prhub

#41326 Faster per-token fp8 group quant packed kernel for blackwell

原始 PR 作者 zyongye 合并时间 2026-05-01 09:09 文件变更 3 提交数 8 评论 5 代码增减 +366 / -89

执行摘要

Blackwell FP8 分组量化寄存器内核,加速 60%-2x

为 DeepGEMM MoE 的输入量化路径提供更快的 fp8 分组量化打包内核,在 Blackwell GPU 上实现 60%-2x 加速(根据批处理大小)。

该 PR 值得精读,特别是 CUDA 寄存器优化和 shuffle 规约技术。评审中的三个高优先级问题展示了正确性与文档的平衡。后续同类优化可借鉴其 alignasint64_t 的前置检查。

讨论亮点

gemini-code-assist[bot] 提出三个高优先级问题:

  • regs 数组需要 alignas(16) 保证 uint4 加载对齐(正确性)。
  • num_groups_paddednum_scale_elems 应使用 int64_t 防止大输入溢出(潜在 bug)。
  • 头文件注释声称 group_size!=128 会 fallback 到 smem 内核,但实际已报错,文档与实现不一致(回归风险)。
    claude[bot] 指出上述三个问题关键且未被解决,并补充了注释中拼写换行的纯风格 nits。作者在后续提交中全部修复(commit d2cb322)。

实现拆解

  1. per_token_group_quant.cu 中引入寄存器驻留内核 per_token_group_quant_8bit_packed_register_kernel,每个线程持有 32 字节源数据(uint4 x 2)在寄存器中完成 absmax 规约、缩放计算和量化,消除共享内存和 __syncthreads
  2. 移除旧版共享内存内核 per_token_group_quant_8bit_packed_kernel 及其 host 辅助函数,清理 per_token_group_quant_8bit_packed_smem_impl
  3. 将每组的线程数从 16 降至 8,每线程加载量翻倍至 32 字节,shuffle 规约层级减少,提升寄存器利用率。
  4. 在内核中添加 padding 零填充逻辑:当 output_q 存在 TMA 对齐的额外 mn 行时,内核在量化后原位清零这些行,确保下游无垃圾数据。
  5. 根据 review 反馈修复:加入 alignas(16) 保证 uint4 加载对齐;将 num_groups_paddednum_scale_elemsint 改为 int64_t 避免大输入溢出;更新头文件注释,明确 group_size!=128 时报错而非 fallback。
  6. 测试文件新增三个测试用例(all_zero、mantissa_rounds_up、zero_fills_padded_output_q),移除非 2 次幂 group size 参数(96)以匹配内核限制。
文件 模块 状态 重要度
csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu 量化内核 modified 6.19
csrc/libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h 量化接口 modified 5.68
tests/kernels/quantization/test_per_token_group_quant.py 量化测试 modified 6.91

关键符号

per_token_group_quant_8bit_packed_register_kernel per_token_group_quant_8bit_packed test_per_token_group_quant_fp8_packed_all_zero test_per_token_group_quant_fp8_packed_mantissa_rounds_up test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q

关键源码片段

csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu core-logic

核心变更文件:实现寄存器驻留 packed 量化内核,移除旧共享内存内核,调整 host 侧调度逻辑。

// Register-resident fast path for group_size==128.
// 每个线程持有 16 个源元素(32 B,两个 uint4)在寄存器中完成:
// absmax 规约 → 缩放计算 → 量化流水线。不使用共享内存和 __syncthreads。
// UE8M0 缩放因子通过位运算提取(与 exp2f(ceilf(log2f)) 位精确)。
// Blackwell nvcc 会将连续 uint4 加载融合为单条 256-bit LDG.E.256 指令。
template <typename T, typename DST_DTYPE, int GROUP_SIZE>
__global__ void per_token_group_quant_8bit_packed_register_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    unsigned int* __restrict__ output_s_packed,
    const int64_t num_groups_padded,
    const int groups_per_block,
    const int padded_groups_per_row,
    const int groups_per_row,
    const int mn,
    const int output_q_mn_extent,
    const int tma_aligned_mn,
    const int64_t num_scale_elems,
    const float eps,
    const float min_8bit,
    const float max_8bit) {
  static_assert(GROUP_SIZE == 128, "fast path supports GROUP_SIZE==128");
  constexpr int THREADS_PER_GROUP = 8;
  constexpr int VEC_SIZE = 32 / sizeof(T); // 对于 bf16/fp16 为 16
  /* 每个组的 8 个线程必须位于同一个 warp 八位组内,使得 shuffle mask
     0xffu << (threadIdx.x & 24u) 恰好选中同组线程。要求
     32 % THREADS_PER_GROUP == 0 且启动线程数为 group_per_block * 8。 */
  static_assert(32 % THREADS_PER_GROUP == 0);
  // ... 后续实现包括 absmax 规约、缩放计算、量化和存储
}
csrc/libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h core-logic

头文件添加新函数声明 `per_token_group_quant_8bit_packed`,并更新注释说明限制。

#pragma once#include <torch/csrc/stable/tensor.h>// 8-bit 逐 token 分组量化辅助函数(INT8 和 FP8 共用)
void per_token_group_quant_8bit(const torch::stable::Tensor& input,
                                torch::stable::Tensor& output_q,
                                torch::stable::Tensor& output_s,
                                int64_t group_size, double eps, double min_8bit,
                                double max_8bit, bool scale_ue8m0 = false);// Public op: 面向 DeepGEMM Blackwell 路径的寄存器驻留 packed 量化。
// 仅支持 group_size == 128 且 bf16/fp16 输入;其他配置直接抛出 STD_TORCH_CHECK。
// 旧版共享内存回退已被移除,因为无生产调用方使用其他形状。
void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
                                       torch::stable::Tensor& output_q,
                                       torch::stable::Tensor& output_s_packed,
                                       int64_t group_size, double eps,
                                       double min_8bit, double max_8bit);
tests/kernels/quantization/test_per_token_group_quant.py test-coverage

测试文件新增三个边界测试用例,移除不支持的 group size 参数化。

@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
)
def test_per_token_group_quant_fp8_packed_all_zero():
    """全零输入必须通过内核 UE8M0 路径中的 eps 下界产生确定的 UE8M0 缩放字节。
    在优化前锁定全零行为。    CUDA 内核计算:
        y_s = eps / fp8_max
        y_s = exp2(ceil(log2(fmax(y_s, 1e-10))))
    对于全零输入,eps/fp8_max < 1e-10,内部 fmax 钳回到 1e-10,得到
    exp2(ceil(log2(1e-10))) = exp2(-33) => UE8M0 字节 0x5E (94)。
    """
    device = "cuda"
    num_tokens, hidden_dim, group_size = 4, 7168, 128
    x = torch.zeros((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16)
​
    out_q, out_s_packed = fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
        x, group_size=group_size, use_ue8m0=True,
    )
​
    # 量化输出必须全零
    assert torch.equal(
        out_q.view(torch.uint8),
        torch.zeros_like(out_q, dtype=torch.uint8),
    ), "All-zero input should produce all-zero FP8 output"
​
    # 全零输入时内核产生的 UE8M0 字节
    expected_exp_byte = 0x5E
​
    mn = num_tokens
    groups_per_row = hidden_dim // group_size
    k_num_packed = (groups_per_row + 3) // 4
    tma_aligned_mn = ((mn + 3) // 4) * 4
    num_scale_elems = mn + (k_num_packed - 1) * tma_aligned_mn
​
    actual = torch.as_strided(out_s_packed, (num_scale_elems,), (1,)).cpu()
    expected = torch.zeros(num_scale_elems, dtype=torch.int32, device="cpu")
    for row in range(mn):
        for g in range(groups_per_row):
            pack_col = g // 4
            pos = g % 4
            idx = pack_col * tma_aligned_mn + row
            expected[idx] |= expected_exp_byte << (pos * 8)
​
    assert torch.equal(actual, expected), "All-zero scale bytes mismatch"

评论区精华

寄存器数组对齐要求 alignas(16) 正确性

gemini-code-assist[bot] 指出 `regs` 数组通过 uint4* 强制转换进行向量化加载,需要 16 字节对齐,否则可能引发对齐故障或未定义行为。

结论:作者在 commit d2cb322 中添加了 alignas(16) 修复。 · 已解决

参数类型 int 应改为 int64_t 防止溢出 正确性

gemini-code-assist[bot] 指出 `num_groups_padded` 和 `num_scale_elems` 在 host 端计算为 int64_t,但 kernel 签名使用 int,大输入会溢出。

结论:作者在后续提交中将 kernel 签名改为 int64_t。 · 已解决

头文件注释与实现不符:fallback 误导 documentation

gemini-code-assist[bot] 发现头文件注释声称 group_size!=128 时 fallback 到 smem 内核,但实际已移除 smem 内核并直接报错。

结论:作者更新注释,明确限制并移除 fallback 描述。 · 已解决

注释换行拼写风格问题 style

claude[bot] 指出 kernel 注释中的换行导致 `1e-10f` 与 `exp_byte` 连接,纯风格问题。

结论:作者在后续提交中修复。 · 已解决

风险与影响

  1. group_size 硬限制:新内核仅支持 group_size==128,非 128 配置直接报错。若上游调用方(如 input_quant_fp8)可能使用其他 group size,则会功能回归。但 PR 说明无生产路径使用非 128 分组。
  2. 对齐未定义行为:若 alignas(16) 未正确应用,uint4 加载可能触发硬件异常或性能降级——已在 review 后修复。
  3. 输入溢出:原本 int 类型在超大上下文时可能溢出,已改为 int64_t
  4. 平台限制:新内核依赖 CUDA 12.8+ 指令(256-bit 加载),且测试仅在 CUDA 平台上运行;AMD 或其他后端不可用。
  5. 测试覆盖:新测试仅覆盖 CUDA 平台,但 DeepGEMM not available on this platform 的 skip 条件合理。

用户影响:Blackwell GPU 上使用 DeepGEMM MoE 的服务将获得显著加速(60%-2x),无需任何配置变更。非 Blackwell 或 group_size!=128 的用户不受影响(原有路径保持不变)。系统影响:内核二进制体积减少(移除 smem 变体),启动时间可能略有改善。团队影响:维护性降低,因为代码行数减少且依赖条件集中;但未来若需支持其他 group size 需重新设计。

去除共享内存回退可能影响非 128 组大小 对齐问题已修复 大输入溢出已修复

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论