执行摘要
- 一句话:在ROCm上集成AITER自定义all-gather,加速TP通信
- 推荐动作:值得精读。该PR展示了在大型项目中安全集成第三方加速库的范例:环境变量开关、完备的fallback、CUDA图各阶段一致性处理、以及配套的benchmark和CI测试。
_all_gather_into_tensor中的条件编排和状态分支设计可供参考。
功能与动机
PR body指出,SGLang在ROCm上现有TP all-gather使用RCCL路径,但AITER提供的自定义all-gather在DeepSeek-R1-MXFP4-Preview服务的logits形状下更快且输出bit一致。因此希望新增一条HIP门控的AITER路径,类似已有的AITER custom all-reduce集成方式。
实现拆解
- 环境变量注册:在
python/sglang/srt/environ.py的Envs类中新增SGLANG_USE_AITER_AG = EnvBool(True),默认启用。
- 核心逻辑改造:在
python/sglang/srt/distributed/parallel_state.py的GroupCoordinator._all_gather_into_tensor方法开头插入AITER分支。条件检查:ROCm平台、env启用、ca_comm存在且拥有should_custom_ag等方法、输入输出连续、dtype为float32/16/bfloat16、should_custom_ag返回True。当CUDA图正在捕获时调用all_gather_reg,piecewise图调用all_gather_unreg,普通warmup阶段则output.zero_()并返回(避免混合使用NCCL和AITER导致图不一致)。不满足条件时fallback到原有pynccl或torch.distributed.all_gather_into_tensor。
- 辅助方法:新增
_has_aiter_custom_all_gather检查ca_comm具备必要方法且未禁用确定性collectives;_deterministic_collectives_enabled判断确定性推理配置。
- Benchmark与测试:新增
benchmark/kernels/all_gather/benchmark_aiter.py,支持多shape/dtype/dim扫描,并在计时前进行正确性检查(torch.equal对比RCCL与AITER输出)。新增test/registered/ops/test_aiter_allgather_amd.py,作为CI测试调用benchmark脚本在TP=2,4,8下验证多种dtype的正确性。
- 配套修复:根据review意见重构了CUDA图捕获/预热路径,避免warmup时走NCCL而捕获时走AITER的不一致问题;同时添加了dtype guard防止非float类型进入AITER(AITER仅支持float32/16/bfloat16)。
关键文件:
python/sglang/srt/distributed/parallel_state.py(模块 分布式;类别 source;类型 core-logic;符号 _has_aiter_custom_all_gather, _deterministic_collectives_enabled): 核心改动:在_all_gather_into_tensor中插入AITER自定义all-gather分支,并新增_has_aiter_custom_all_gather等辅助方法。包含CUDA图状态处理逻辑。
benchmark/kernels/all_gather/benchmark_aiter.py(模块 基准测试;类别 source;类型 dependency-wiring;符号 parse_shape_list, parse_dtype_list, parse_dim_list, parse_args): 新增benchmark脚本,支持多shape/dtype/dim扫描,在计时前验证AITER与RCCL的正确性,是性能数据和CI测试的基础。
test/registered/ops/test_aiter_allgather_amd.py(模块 测试;类别 test;类型 test-coverage;符号 TestAiterAllGatherAmd, _gpu_count, test_aiter_allgather_matches_rccl): CI测试:在TP=2/4/8上运行benchmark脚本的--correctness-only模式,覆盖10种dtype和两种形状,确保AITER路径正确性。
python/sglang/srt/environ.py(模块 配置;类别 source;类型 core-logic): 注册SGLANG_USE_AITER_AG环境变量,默认True,使能AITER all-gather开关。
关键符号:_all_gather_into_tensor, _has_aiter_custom_all_gather, _deterministic_collectives_enabled, reshape_logical, raw_allgather_shape, test_aiter_allgather_matches_rccl
关键源码片段
python/sglang/srt/distributed/parallel_state.py
核心改动:在_all_gather_into_tensor中插入AITER自定义all-gather分支,并新增_has_aiter_custom_all_gather等辅助方法。包含CUDA图状态处理逻辑。
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
# AITER custom all-gather (ROCm). Set SGLANG_USE_AITER_AG=0 to disable.
# AITER's should_custom_ag handles shape/layout validation:
# 16B alignment, weak-contiguous, supported topology, and per-rank
# size <= max_size/(world*2).
ca_comm = self.ca_comm
if (
is_hip()
and envs.SGLANG_USE_AITER_AG.get()
and self._has_aiter_custom_all_gather()
and input.is_contiguous()
and output.is_contiguous()
# AITER only supports float types; int tensors fall through to NCCL
and input.dtype in (torch.float32, torch.float16, torch.bfloat16)
and ca_comm.should_custom_ag(input)
):
if getattr(ca_comm, "_IS_CAPTURING", False):
if torch.cuda.is_current_stream_capturing():
# Actual capture: register buffer for graph replay
ca_comm.all_gather_reg(input, out=output, dim=0)
elif is_in_piecewise_cuda_graph():
# Piecewise graph: use unregistered version
ca_comm.all_gather_unreg(input, out=output, dim=0)
else:
# Warmup: zero out output to keep graph shape, avoid NCCL
output.zero_()
return
else:
# Eager mode: unregistered variant
ca_comm.all_gather_unreg(input, out=output, dim=0)
return
# Fallback: pynccl or torch.distributed
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and (
not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
):
... # existing path
else:
torch.distributed.all_gather_into_tensor(
output, input, group=self.device_group
)
def _has_aiter_custom_all_gather(self) -> bool:
if self._deterministic_collectives_enabled():
return False
ca_comm = self.ca_comm
return (
ca_comm is not None
and not getattr(ca_comm, "disabled", True)
and hasattr(ca_comm, "should_custom_ag")
and hasattr(ca_comm, "all_gather_reg")
and hasattr(ca_comm, "all_gather_unreg")
)
@staticmethod
def _deterministic_collectives_enabled() -> bool:
if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():
return envs.SGLANG_USE_1STAGE_ALLREDUCE.get()
return envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get()
评论区精华
风险与影响
- 风险:
- 平台限制:仅AMD ROCm/HIP平台生效,其他平台不受影响。
- 第三方依赖:需要AITER库及
ca_comm提供should_custom_ag、all_gather_reg、all_gather_unreg等方法;若AITER不可用或通信器类型不支持,自动回退到RCCL。
- dtype限制:AITER仅支持float32/16/bfloat16,int类型会触发回退。PR已在条件中显式检查dtype,避免误入AITER导致崩溃(review中已发现并修复该问题)。
- CUDA图一致性:最终代码清晰区分捕获、piecewise图、warmup三种状态,避免混合使用不同collective导致图无效。
- 回归风险:仅新增分支,现有路径完全保留;但如果
should_custom_ag在非预期形状返回True可能导致性能损失而非错误。
- 影响:
- 用户影响:AMD ROCm用户默认获得all-gather性能提升(典型加速1.13-1.71x),可通过设置
SGLANG_USE_AITER_AG=0禁用。
- 系统影响:无外部接口变更,内部通信路径优化。
- 团队影响:测试和基准脚本可作为后续其他collective优化的模板。
- 风险标记:仅AMD平台, 依赖AITER第三方库, CUDA图一致性已修复, dtype限制需关注
关联脉络
参与讨论