Prhub

#25093 [AMD] Enable AITER custom all-gather on ROCm

原始 PR 作者 hubertlu-tw 合并时间 2026-06-03 06:57 文件变更 4 提交数 8 评论 17 代码增减 +625 / -0

执行摘要

在 ROCm 上集成 AITER 自定义 all-gather,加速 TP 通信

PR body指出,SGLang在ROCm上现有TP all-gather使用RCCL路径,但AITER提供的自定义all-gather在DeepSeek-R1-MXFP4-Preview服务的logits形状下更快且输出bit一致。因此希望新增一条HIP门控的AITER路径,类似已有的AITER custom all-reduce集成方式。

值得精读。该PR展示了在大型项目中安全集成第三方加速库的范例:环境变量开关、完备的fallback、CUDA图各阶段一致性处理、以及配套的benchmark和CI测试。_all_gather_into_tensor中的条件编排和状态分支设计可供参考。

讨论亮点
  • gemini-code-assist[bot]指出CUDA图warmup阶段如果_IS_CAPTURING为True但torch.cuda.is_current_stream_capturing()为False,代码会fallthrough到NCCL,导致warmup用NCCL而实际捕获用AITER,可能引发问题。hubertlu-tw确认并重构代码,最终通过output.zero_()+return避免在warmup时调用任何collective。
  • hubertlu-tw在dtype guard上注释:当前仅支持浮点类型,未来AITER kernel扩展后可放宽限制。

实现拆解

  1. 环境变量注册:在python/sglang/srt/environ.pyEnvs类中新增SGLANG_USE_AITER_AG = EnvBool(True),默认启用。
  2. 核心逻辑改造:在python/sglang/srt/distributed/parallel_state.pyGroupCoordinator._all_gather_into_tensor方法开头插入AITER分支。条件检查:ROCm平台、env启用、ca_comm存在且拥有should_custom_ag等方法、输入输出连续、dtype为float32/16/bfloat16、should_custom_ag返回True。当CUDA图正在捕获时调用all_gather_reg,piecewise图调用all_gather_unreg,普通warmup阶段则output.zero_()并返回(避免混合使用NCCL和AITER导致图不一致)。不满足条件时fallback到原有pynccl或torch.distributed.all_gather_into_tensor
  3. 辅助方法:新增_has_aiter_custom_all_gather检查ca_comm具备必要方法且未禁用确定性collectives;_deterministic_collectives_enabled判断确定性推理配置。
  4. Benchmark与测试:新增benchmark/kernels/all_gather/benchmark_aiter.py,支持多shape/dtype/dim扫描,并在计时前进行正确性检查(torch.equal对比RCCL与AITER输出)。新增test/registered/ops/test_aiter_allgather_amd.py,作为CI测试调用benchmark脚本在TP=2,4,8下验证多种dtype的正确性。
  5. 配套修复:根据review意见重构了CUDA图捕获/预热路径,避免warmup时走NCCL而捕获时走AITER的不一致问题;同时添加了dtype guard防止非float类型进入AITER(AITER仅支持float32/16/bfloat16)。
文件 模块 状态 重要度
python/sglang/srt/distributed/parallel_state.py 分布式 modified 7.72
benchmark/kernels/all_gather/benchmark_aiter.py 基准测试 added 8.78
test/registered/ops/test_aiter_allgather_amd.py 测试 added 6.93
python/sglang/srt/environ.py 配置 modified 3.95

关键符号

_all_gather_into_tensor _has_aiter_custom_all_gather _deterministic_collectives_enabled reshape_logical raw_allgather_shape test_aiter_allgather_matches_rccl

关键源码片段

python/sglang/srt/distributed/parallel_state.py core-logic

核心改动:在 _all_gather_into_tensor 中插入 AITER 自定义 all-gather 分支,并新增 _has_aiter_custom_all_gather 等辅助方法。包含 CUDA 图状态处理逻辑。

def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
    # AITER custom all-gather (ROCm). Set SGLANG_USE_AITER_AG=0 to disable.
    # AITER's should_custom_ag handles shape/layout validation:
    # 16B alignment, weak-contiguous, supported topology, and per-rank
    # size <= max_size/(world*2).
    ca_comm = self.ca_comm
    if (
        is_hip()
        and envs.SGLANG_USE_AITER_AG.get()
        and self._has_aiter_custom_all_gather()
        and input.is_contiguous()
        and output.is_contiguous()
        # AITER only supports float types; int tensors fall through to NCCL
        and input.dtype in (torch.float32, torch.float16, torch.bfloat16)
        and ca_comm.should_custom_ag(input)
    ):
        if getattr(ca_comm, "_IS_CAPTURING", False):
            if torch.cuda.is_current_stream_capturing():
                # Actual capture: register buffer for graph replay
                ca_comm.all_gather_reg(input, out=output, dim=0)
            elif is_in_piecewise_cuda_graph():
                # Piecewise graph: use unregistered version
                ca_comm.all_gather_unreg(input, out=output, dim=0)
            else:
                # Warmup: zero out output to keep graph shape, avoid NCCL
                output.zero_()
            return
        else:
            # Eager mode: unregistered variant
            ca_comm.all_gather_unreg(input, out=output, dim=0)
            return
​
    # Fallback: pynccl or torch.distributed
    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and (
        not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
    ):
        ... # existing path
    else:
        torch.distributed.all_gather_into_tensor(
            output, input, group=self.device_group
        )
​
​
def _has_aiter_custom_all_gather(self) -> bool:
    if self._deterministic_collectives_enabled():
        return False
    ca_comm = self.ca_comm
    return (
        ca_comm is not None
        and not getattr(ca_comm, "disabled", True)
        and hasattr(ca_comm, "should_custom_ag")
        and hasattr(ca_comm, "all_gather_reg")
        and hasattr(ca_comm, "all_gather_unreg")
    )@staticmethod
def _deterministic_collectives_enabled() -> bool:
    if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():
        return envs.SGLANG_USE_1STAGE_ALLREDUCE.get()
    return envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get()

评论区精华

CUDA 图 warmup 阶段可能的不一致问题 正确性

gemini-code-assist[bot] 指出,当 _IS_CAPTURING 为 True 但不在捕获中时(warmup),代码会 fallthrough 到 NCCL,导致 warmup 用 NCCL 而实际捕获用 AITER,可能引发问题。

结论:hubertlu-tw 同意并重构:在 warmup 分支中用 output.zero_()+return,避免使用任何 collective。 · 已解决

dtype 限制与未来扩展 设计

hubertlu-tw 注释说明当前仅支持浮点类型,未来 AITER kernel 扩展后可解除限制。

结论:保留 dtype guard,文档化限制。 · 已解决

风险与影响

  • 平台限制:仅AMD ROCm/HIP平台生效,其他平台不受影响。
  • 第三方依赖:需要AITER库及ca_comm提供should_custom_agall_gather_regall_gather_unreg等方法;若AITER不可用或通信器类型不支持,自动回退到RCCL。
  • dtype限制:AITER仅支持float32/16/bfloat16,int类型会触发回退。PR已在条件中显式检查dtype,避免误入AITER导致崩溃(review中已发现并修复该问题)。
  • CUDA图一致性:最终代码清晰区分捕获、piecewise图、warmup三种状态,避免混合使用不同collective导致图无效。
  • 回归风险:仅新增分支,现有路径完全保留;但如果should_custom_ag在非预期形状返回True可能导致性能损失而非错误。
  • 用户影响:AMD ROCm用户默认获得all-gather性能提升(典型加速1.13-1.71x),可通过设置SGLANG_USE_AITER_AG=0禁用。
  • 系统影响:无外部接口变更,内部通信路径优化。
  • 团队影响:测试和基准脚本可作为后续其他collective优化的模板。
仅 AMD 平台 依赖 AITER 第三方库 CUDA 图一致性已修复 dtype 限制需关注

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论