Prhub

#41882 Add NVFP4 all-gather GEMM fusion for AsyncTP

原始 PR 作者 baonudesifeizhai 合并时间 2026-05-10 09:13 文件变更 7 提交数 10 评论 7 代码增减 +605 / -6

执行摘要

为 AsyncTP 添加 NVFP4 all-gather GEMM 融合路径

PR 描述指出,该 PR 将 NVFP4 FlashInfer all-gather + GEMM 路径接入 AsyncTP,以覆盖 SP + AsyncTP 在 NVFP4 下的性能收益。由于 PyTorch 缺少 NVFP4 感知的 fused GEMM + reduce-scatter 支持,当前仅实现 all-gather 融合。

推荐精读,尤其关注 collective_fusion.pyFlashInferAllGatherFP4Patternpatternreplacement 设计,以及 sequence_parallelism.py 中 NVFP4 量化与序列平行的整合方式。对推理性能优化感兴趣的同学可以关注 reduce-scatter 融合的后续进展。

讨论亮点
  • 对称内存建议:gemini-code-assist[bot] 建议对中间缓冲区使用对称内存以避免不必要的拷贝,作者未公开回应但当前实现仍使用 new_empty
  • Double view 逻辑简化:bot 指出 float8_uint8 双重 view 可简化,建议直接 view 为 uint8。后续 commit 中变量命名已调整(从 STATIC_FP4_QUANT_OP 改为 SCALED_FP4_QUANT_OUT_OVERLOAD),但 double view 是否完全清理需确认。
  • Reduce-scatter 可行性:ProExpertProg 提问 reduce-scatter 是否也是 trivially 可支持的,自答输入已经是列并行,无需额外 scale 通信,但当前保持 disable,留作后续。
  • CI 测试覆盖:ProExpertProg 要求将新测试加入正确性 CI,PR 已包含测试但需要确认是否纳入了 CI 流水线。

实现拆解

  1. vllm/utils/flashinfer.py 中添加 flashinfer_scaled_fp4_mm_out 函数,封装了 FlashInfer mm_fp4_ 的调用,支持 pre-allocated output buffer。

  2. collective_fusion.py 中添加 NVFP4 专用的 MM 适配器 _flashinfer_fp4_mm_out,以及 fused all-gather + GEMM 操作的真实实现和 fake 实现,支持对称内存式流水线 all-gather。

  3. 定义 FlashInferAllGatherFP4Pattern 类,通过 patternreplacement 匹配 QKV/MLP 中的 all-gather + MM + FP4 量化组合,注册到 FusionPass。

  4. sequence_parallelism.py 中添加 FirstAllReduceRMSNormStaticNVFP4PatternMiddleAllReduceRMSNormStaticNVFP4Pattern,支持 NVFP4 量化在 SP 中的 AllReduce→RMSNorm→Quant 模式融合。

  5. 添加三组测试:

    • test_tp2_async_tp_nvfp4_fusions:验证融合计数。
    • test_async_tp_pass_nvfp4_correctness:正确性对比。
    • test_tp_sp_nvfp4_generation:SP 模式生成测试。
文件 模块 状态 重要度
vllm/compilation/passes/fusion/collective_fusion.py 编译融合 modified 8.93
vllm/compilation/passes/fusion/sequence_parallelism.py 编译融合 modified 8.65
vllm/utils/flashinfer.py FlashInfer 工具 modified 7.16
tests/compile/correctness_e2e/test_async_tp.py 集成测试 modified 6.19
tests/compile/fusions_e2e/test_tp2_async_tp.py 集成测试 modified 6.15
tests/compile/correctness_e2e/test_sequence_parallel.py 集成测试 modified 5.99
tests/compile/fullgraph/test_toy_llama.py 单测 modified 3.88

关键符号

_flashinfer_fp4_mm_out fused_all_gather_flashinfer_fp4_matmul fused_all_gather_flashinfer_fp4_matmul_fake FlashInferAllGatherFP4Pattern FirstAllReduceRMSNormStaticNVFP4Pattern MiddleAllReduceRMSNormStaticNVFP4Pattern flashinfer_scaled_fp4_mm_out test_async_tp_pass_nvfp4_correctness test_tp2_async_tp_nvfp4_fusions test_tp_sp_nvfp4_generation

关键源码片段

vllm/compilation/passes/fusion/collective_fusion.py core-logic

核心变更文件,新增 NVFP4 all-gather + GEMM 融合的自定义操作和模式注册。

# vllm/compilation/passes/fusion/collective_fusion.pydef fused_all_gather_flashinfer_fp4_matmul(
    A_shard: torch.Tensor,
    B: torch.Tensor,
    A_scale_shard: torch.Tensor,
    B_scale: torch.Tensor,
    alpha: torch.Tensor,
    gather_dim: int,
    group_name: str,
    out_dtype: torch.dtype | None = None,
    view_a_scale_as_fp8: bool = False,
    use_8x4_sf_layout: bool = False,
    backend: str = "cutlass",
) -> torch.Tensor:
    # 只支持 gather_dim=0(按行拼接)
    assert gather_dim == 0, "FP4 symm_mem adapter only supports gather_dim=0"
    assert A_shard.ndim == 2 and A_scale_shard.ndim == 2 and B.ndim == 2
​
    # 可选:将 scale 重解释为 FP8 以复用接口
    if view_a_scale_as_fp8:
        A_scale_shard = A_scale_shard.view(torch.float8_e4m3fn)
​
    group = c10d._resolve_process_group(group_name)
    world_size = group.size()
​
    # 预分配完整输出张量
    output = A_shard.new_empty(
        A_shard.shape[0] * world_size,
        B.shape[1],
        dtype=out_dtype or torch.bfloat16,
    )
    output_shards = output.chunk(world_size)
​
    # 分配 all-gather 目标缓冲区(非对称内存,可能带来拷贝开销)
    A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1])
    A_scale = A_scale_shard.new_empty(
        A_scale_shard.shape[0] * world_size,
        A_scale_shard.shape[1],
    )
​
    # 流水线 all-gather:每个 rank 数据到达后立即计算局部 GEMM
    def fp4_shard_consumer(shards: list[torch.Tensor], rank: int) -> None:
        _flashinfer_fp4_mm_out(
            shards[0],
            B,
            scale_a=shards[1],
            scale_b=B_scale,
            alpha=alpha,
            out=output_shards[rank],
            out_dtype=out_dtype,
            use_8x4_sf_layout=use_8x4_sf_layout,
            backend=backend,
        )
​
    # 融合 all-gather 与计算
    torch.distributed._symmetric_memory._pipelined_multi_all_gather(
        [A, A_scale],
        [A_shard, A_scale_shard],
        group_name,
        stream_consumer=fp4_shard_consumer,
    )
    return output
vllm/compilation/passes/fusion/sequence_parallelism.py core-logic

添加 NVFP4 序列并行模式,将 AllReduce→RMSNorm→Quant 转换为 ReduceScatter→RMSNorm→Quant→AllGather。

# vllm/compilation/passes/fusion/sequence_parallelism.pyclass FirstAllReduceRMSNormStaticNVFP4Pattern(_SequenceParallelPatternHelper):
    def get_inputs(self) -> list[torch.Tensor]:
        # 创建示例张量供模式匹配器使用
        input = self.empty([8, 16])
        weight = self.empty([16])
        input_global_scale = self.empty_f32([1, 1])
        quant_output = torch.empty([8, 8], device=self.device, dtype=torch.uint8)
        output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
        return [input, weight, input_global_scale, quant_output, output_scale]
​
    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input, weight, input_global_scale, quant_output, output_scale
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # 原始图 : AllReduce -> RMSNorm -> NVFP4 Quant
            all_reduce = self._all_reduce(input)
            rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
            quant = auto_functionalized(
                SCALED_FP4_QUANT_OUT_OVERLOAD,
                input=rms,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
                output=quant_output,
                output_scale=output_scale,
            )
            return quant[1], all_reduce, quant[2]
​
        def replacement(
            input, weight, input_global_scale, quant_output, output_scale
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # 替换图 : ReduceScatter -> RMSNorm -> NVFP4 Quant -> AllGather
            reduce_scatter = self._reduce_scatter(input)
            rms = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
            rms = torch.ops.aten.view.default(rms, [-1, rms.shape[-1]])
            quant = SCALED_FP4_QUANT_DEFAULT_OVERLOAD(
                rms, input_global_scale, True,
            )
            return (
                self._all_gather(quant[0]),
                reduce_scatter,
                self._all_gather(quant[1]),
            )
​
        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
vllm/utils/flashinfer.py core-logic

新增 `flashinfer_scaled_fp4_mm_out` 函数,封装 FlashInfer FP4 mm 的 `out` 变体调用。

# vllm/utils/flashinfer.pydef flashinfer_scaled_fp4_mm_out(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out: torch.Tensor,
    out_dtype: torch.dtype | None,
    use_8x4_sf_layout: bool,
    backend: str,
) -> torch.Tensor:
    # 所有张量必须是 2D,并且 out 已经预分配好
    assert a.ndim == 2 and b.ndim == 2 and out.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert out.shape == (a.shape[0], b.shape[1])
​
    # 对 cutlass / cudnn 后端,将 scale 重解释为 uint8
    if backend in ("cutlass", "cudnn"):
        if block_scale_a.dtype != torch.uint8:
            block_scale_a = block_scale_a.view(torch.uint8)
        if block_scale_b.dtype != torch.uint8:
            block_scale_b = block_scale_b.view(torch.uint8)
​
    from flashinfer import mm_fp4 as flashinfer_mm_fp4_
    # 调用 FlashInfer 的 FP4 matmul,输出写入预先分配的 out
    flashinfer_mm_fp4_(
        a, b, block_scale_a, block_scale_b, alpha,
        out_dtype or out.dtype,
        out=out,
        block_size=16,
        use_8x4_sf_layout=use_8x4_sf_layout,
        backend=backend,
    )
    return out

评论区精华

使用对称内存优化中间缓冲区 性能

gemini-code-assist[bot] 指出 fused_all_gather_flashinfer_fp4_matmul 中中间缓冲区 A 和 A_scale 使用 new_empty 分配,建议使用对称内存避免拷贝,并提及频繁分配大缓冲区的开销。

结论:未在 PR 中明显采纳,当前仍使用 new_empty。 · unresolved

简化 a_scale_view 的 double view 逻辑 style

gemini-code-assist[bot] 指出 double view (float8_uint8) 多余,建议直接 view uint8。

结论:后续可能已调整(从 STATIC_FP4_QUANT_OP 重命名),但 double view 是否完全消除需确认。 · addressed

NVFP4 reduce-scatter 融合是否必要 设计

ProExpertProg 提问 reduce-scatter 是否 trivial,自答输入已列并行,无需 scale 通信,但当前特意不启用。

结论:确认设计有意,留作后续。 · 已解决

将 NVFP4 测试加入 CI 测试

ProExpertProg 要求将新测试加入 SP 和 AsyncTP 正确性 CI。

结论:测试已添加,但需要确保 CI 配置包含这些测试。 · addressed

量化操作符命名 style

ProExpertProg 评论 'Lol I don't think this is static vs dynamic, these are just the overloads',暗示 STATIC_FP4_QUANT_OP 命名不当。

结论:后续提交中已改为 SCALED_FP4_QUANT_OUT_OVERLOAD。 · 已解决

风险与影响

  • 量化类型检查:NVFP4 相关代码仅在 torch.ops._C.scaled_fp4_quant 存在时注册,若 FlashInfer 不可用或设备不支持则跳过,不会影响现有 FP8 路径。
  • 性能风险:中间缓冲区使用 new_empty 而非对称内存,可能带来额外拷贝开销,削弱 AsyncTP 收益(尤其小 batch 场景)。
  • 硬件依赖:仅 Blackwell(SM100)和 FlashInfer 支持,其他平台无影响。
  • 测试覆盖:目前仅测试 TP=2,缺少更大规模 TP 或不同序列长度组合的测试。
  • 用户:使用 NVFP4 量化的 Llama 等模型在启用 SP + AsyncTP 时获得 0.89%-13.54% 的吞吐提升,无需修改模型代码。
  • 系统:新增的 fused 操作会增加编译通道的模式匹配负担,但仅在 NVFP4 路径下生效。
  • 团队:降低了后续添加 reduce-scatter 融合的门槛,为其他量化类型提供了参考模式。
仅 Blackwell 支持 依赖 FlashInfer 对称内存未使用 Reduce-scatter 未融合 仅 2 GPU 测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论