Prhub

#21986 [AMD] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist

sgl-project/sglang · 作者 hubertlu-tw · 合并时间 2026-04-12 14:47

分析状态 已生成
文件变更 4提交数 4 · 评论 4
代码增减 +263 / -56
amd refactor run-ci consistency

执行摘要

修复 AMD 平台融合 allreduce 阈值并移除 hidden_dim 白名单,简化维护。

PR body明确指出两个问题:1. communicator.py中的激活门使用了<比较符,而AITER内部使用<=,导致在边界大小(如hidden_size=4096、bf16、8192 tokens)时融合路径被错误拒绝;2. parallel_state.py中维护了一个hidden_dim白名单{512, 1024, 2048, 4096}用于1-stage vs 2-stage选择,但这与AITER C++层的检查冗余,增加了每新增模型时的手动维护成本。目标是消除白名单,让AITER的启发式方法自动处理支持性,并确保阈值匹配以避免漏激活。

该PR值得精读,特别是parallel_state.py中移除白名单的设计决策,展示了如何将策略下放至底层内核以简化上层逻辑;同时,测试文件中的残差精度检查函数是验证数值正确性的良好范例,有助于理解融合allreduce的准确性保障。

讨论亮点

review评论中没有具体讨论,但PR body详细阐述了设计决策:移除白名单是因为AITER C++层已通过n % pack_size == 0 && n/pack_size <= 1024检查支持性,Python侧白名单纯属冗余;保留128 KB字节阈值是为了防止大预填充批次触发1-stage内核的硬限制(kMaxBlocks=80 tokens)。结论是依赖下层调度更安全且减少维护,已通过GSM8K准确率测试验证无回归。

实现拆解

  1. 修复communicator.py阈值比较符:在apply_aiter_all_reduce_fusion函数中,将total_bytes < 8 * 1024 * 8192改为total_bytes <= 8 * 1024 * 8192,以匹配AITER内部should_custom_ar使用的<=边界,确保在64 MB阈值处正确激活融合路径。
  2. 移除parallel_state.py白名单并简化逻辑:在fused_allreduce_rmsnorm方法中,删除对hidden_dim in {512, 1024, 2048, 4096}的检查,仅保留total_bytes <= 128 * 1024作为1-stage选择的字节阈值,并移除SGLANG_ENABLE_DETERMINISTIC_INFERENCE的强制1-stage逻辑;更新方法文档注明“ROCm/HIP Only”,强调依赖AITER C++调度。
  3. 增强测试覆盖:在test/registered/ops/test_aiter_allreduce_fusion_amd.py中新增_run_residual_accuracy_check函数,用于分布式验证1-stage/2-stage路径的残差输出比特级准确性,并添加多hidden_dim测试用例(如2880, 4096, 5120等)和基准测试调用。
  4. 调整基准测试配置:更新benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py的--prefill-shapes--decode-shapes默认值,包含新增的hidden_dim(如2880),以在CI中验证多维度性能。
文件 模块 状态 重要度
python/sglang/srt/distributed/parallel_state.py 分布式并行 modified 6.47
test/registered/ops/test_aiter_allreduce_fusion_amd.py 融合测试 modified 7.24
python/sglang/srt/layers/communicator.py 通信层 modified 4.82
benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py 基准测试 modified 4.67
python/sglang/srt/distributed/parallel_state.py core-logic

这是核心调度逻辑文件,移除了 hidden_dim 白名单,简化了 1-stage vs 2-stage 选择,直接影响融合 allreduce 的激活行为。

def fused_allreduce_rmsnorm(
    self,
    input_: torch.Tensor,
    residual_inp_: torch.Tensor,
    weight_: torch.Tensor,
    eps: float,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Attempt fused all-reduce + RMSNorm via custom all-reduce communicator. ROCm/HIP Only    1-stage vs 2-stage选择:1-stage内核每个token启动一个块,上限为80 tokens(kMaxBlocks)。
    通过字节阈值保护,使大预填充批次回退到2-stage内核,避免运行时错误。
    AITER的C++分发层已控制哪些hidden_dim有有效的1-stage支持,Python侧无需重复检查。
    """
    ca_comm = self.ca_comm
    if ca_comm is None or getattr(ca_comm, "disabled", True):
        return None
​
    # 优先使用communicator原生的融合API
    if hasattr(ca_comm, "fused_allreduce_rmsnorm"):
        try:
            return ca_comm.fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)
        except Exception:
            # 回退到custom_fused_ar_rms路径
            pass
​
    if not hasattr(ca_comm, "custom_fused_ar_rms"):
        return None
​
    # 决策逻辑:环境变量覆盖优先,否则基于字节阈值选择1-stage
    if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():
        use_1stage_ar = envs.SGLANG_USE_1STAGE_ALLREDUCE.get()
    else:
        total_bytes = input_.numel() * input_.element_size()
        use_1stage_ar = total_bytes <= 128 * 1024 # 仅保留字节阈值,移除hidden_dim白名单
​
    fused_outputs = ca_comm.custom_fused_ar_rms(
        input_,
        residual_inp_,
        weight_,
        eps,
        use_1stage_ar,
    )
    return fused_outputs

关键符号

fused_allreduce_rmsnorm apply_aiter_all_reduce_fusion _run_residual_accuracy_check

评论区精华

移除 hidden_dim 白名单的决策 设计

PR body 中讨论为什么移除白名单:因为 AITER C++ 层已通过 n % pack_size == 0 && n/pack_size <= 1024 检查支持性,Python 侧白名单冗余且增加维护成本。

结论:决定移除白名单,仅保留 128 KB 字节阈值,让 AITER C++ 调度自动处理支持性,这更安全且简化代码。 · 已解决

修复阈值比较符以匹配 AITER 内部逻辑 正确性

发现 communicator.py 中 total_bytes < 8 * 1024 * 8192 与 AITER 的 should_custom_ar 使用 <= 不匹配,导致边界大小被错误拒绝。

结论:将比较符改为 <=,确保在 64 MB 阈值处正确激活融合路径,避免漏激活。 · 已解决

风险与影响

技术风险包括:1. 回归风险:移除白名单后,如果AITER C++层对某些hidden_dim支持不足,可能静默回退到2-stage,但PR提到这是安全的,因为AITER会覆盖unsupported dim;2. 性能风险:字节阈值保留,但off-by-one修复可能使更多小批次激活融合路径,需确保AITER内核性能稳定;3. 兼容性风险:依赖外部AITER PR(#2586和#2453),若未正确集成可能导致数值问题,但测试中添加了残差精度检查以缓解。具体风险点位于parallel_state.py的调度逻辑和communicator.py的阈值比较。

对用户影响:AMD平台用户在使用--enable-aiter-allreduce-fusion时,将更准确地激活融合路径,并支持更多hidden_dim模型而无需手动配置,提升体验和性能。对系统影响:简化了调度逻辑,减少了代码维护负担,使allreduce融合更健壮和自适应。对团队影响:促进了依赖下层组件决策的设计模式,提高了代码可维护性,并通过测试增强确保质量。

核心路径变更 依赖外部组件 移除白名单可能引入兼容性风险

关联 Issue

#2453 Refactor allreduce for supporting prefill case
#2586 Fix: Numerical Accuracy in `allreduce_fusion_kernel_1stage`
#2586 Tiny code cleanup in tokenizer_manager.py

完整报告

执行摘要

  • 一句话:修复AMD平台融合allreduce阈值并移除hidden_dim白名单,简化维护。
  • 推荐动作:该PR值得精读,特别是parallel_state.py中移除白名单的设计决策,展示了如何将策略下放至底层内核以简化上层逻辑;同时,测试文件中的残差精度检查函数是验证数值正确性的良好范例,有助于理解融合allreduce的准确性保障。

功能与动机

PR body明确指出两个问题:1. communicator.py中的激活门使用了<比较符,而AITER内部使用<=,导致在边界大小(如hidden_size=4096、bf16、8192 tokens)时融合路径被错误拒绝;2. parallel_state.py中维护了一个hidden_dim白名单{512, 1024, 2048, 4096}用于1-stage vs 2-stage选择,但这与AITER C++层的检查冗余,增加了每新增模型时的手动维护成本。目标是消除白名单,让AITER的启发式方法自动处理支持性,并确保阈值匹配以避免漏激活。

实现拆解

  1. 修复communicator.py阈值比较符:在apply_aiter_all_reduce_fusion函数中,将total_bytes < 8 * 1024 * 8192改为total_bytes <= 8 * 1024 * 8192,以匹配AITER内部should_custom_ar使用的<=边界,确保在64 MB阈值处正确激活融合路径。
  2. 移除parallel_state.py白名单并简化逻辑:在fused_allreduce_rmsnorm方法中,删除对hidden_dim in {512, 1024, 2048, 4096}的检查,仅保留total_bytes <= 128 * 1024作为1-stage选择的字节阈值,并移除SGLANG_ENABLE_DETERMINISTIC_INFERENCE的强制1-stage逻辑;更新方法文档注明“ROCm/HIP Only”,强调依赖AITER C++调度。
  3. 增强测试覆盖:在test/registered/ops/test_aiter_allreduce_fusion_amd.py中新增_run_residual_accuracy_check函数,用于分布式验证1-stage/2-stage路径的残差输出比特级准确性,并添加多hidden_dim测试用例(如2880, 4096, 5120等)和基准测试调用。
  4. 调整基准测试配置:更新benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py的--prefill-shapes--decode-shapes默认值,包含新增的hidden_dim(如2880),以在CI中验证多维度性能。

关键文件:

  • python/sglang/srt/distributed/parallel_state.py(模块 分布式并行;类别 source;类型 core-logic;符号 fused_allreduce_rmsnorm): 这是核心调度逻辑文件,移除了hidden_dim白名单,简化了1-stage vs 2-stage选择,直接影响融合allreduce的激活行为。
  • test/registered/ops/test_aiter_allreduce_fusion_amd.py(模块 融合测试;类别 test;类型 test-coverage;符号 _run_residual_accuracy_check, test_fused_ar_rms_multi_hidden_dim, test_fused_ar_rms_residual_accuracy, test_fused_ar_rms_benchmark): 测试文件大幅增强,新增残差精度检查函数和多hidden_dim测试,确保移除白名单后的数值正确性和覆盖性。
  • python/sglang/srt/layers/communicator.py(模块 通信层;类别 source;类型 core-logic;符号 apply_aiter_all_reduce_fusion): 修复了AITER融合allreduce激活阈值的off-by-one错误,确保与AITER内部逻辑一致。
  • benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py(模块 基准测试;类别 source;类型 configuration): 更新基准测试的默认形状配置,包含新增的hidden_dim如2880,以在CI中验证多维度性能。

关键符号:fused_allreduce_rmsnorm, apply_aiter_all_reduce_fusion, _run_residual_accuracy_check

关键源码片段

python/sglang/srt/distributed/parallel_state.py

这是核心调度逻辑文件,移除了hidden_dim白名单,简化了1-stage vs 2-stage选择,直接影响融合allreduce的激活行为。

def fused_allreduce_rmsnorm(
    self,
    input_: torch.Tensor,
    residual_inp_: torch.Tensor,
    weight_: torch.Tensor,
    eps: float,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Attempt fused all-reduce + RMSNorm via custom all-reduce communicator. ROCm/HIP Only    1-stage vs 2-stage选择:1-stage内核每个token启动一个块,上限为80 tokens(kMaxBlocks)。
    通过字节阈值保护,使大预填充批次回退到2-stage内核,避免运行时错误。
    AITER的C++分发层已控制哪些hidden_dim有有效的1-stage支持,Python侧无需重复检查。
    """
    ca_comm = self.ca_comm
    if ca_comm is None or getattr(ca_comm, "disabled", True):
        return None
​
    # 优先使用communicator原生的融合API
    if hasattr(ca_comm, "fused_allreduce_rmsnorm"):
        try:
            return ca_comm.fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)
        except Exception:
            # 回退到custom_fused_ar_rms路径
            pass
​
    if not hasattr(ca_comm, "custom_fused_ar_rms"):
        return None
​
    # 决策逻辑:环境变量覆盖优先,否则基于字节阈值选择1-stage
    if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():
        use_1stage_ar = envs.SGLANG_USE_1STAGE_ALLREDUCE.get()
    else:
        total_bytes = input_.numel() * input_.element_size()
        use_1stage_ar = total_bytes <= 128 * 1024 # 仅保留字节阈值,移除hidden_dim白名单
​
    fused_outputs = ca_comm.custom_fused_ar_rms(
        input_,
        residual_inp_,
        weight_,
        eps,
        use_1stage_ar,
    )
    return fused_outputs

评论区精华

review评论中没有具体讨论,但PR body详细阐述了设计决策:移除白名单是因为AITER C++层已通过n % pack_size == 0 && n/pack_size <= 1024检查支持性,Python侧白名单纯属冗余;保留128 KB字节阈值是为了防止大预填充批次触发1-stage内核的硬限制(kMaxBlocks=80 tokens)。结论是依赖下层调度更安全且减少维护,已通过GSM8K准确率测试验证无回归。

  • 移除hidden_dim白名单的决策 (design): 决定移除白名单,仅保留128 KB字节阈值,让AITER C++调度自动处理支持性,这更安全且简化代码。
  • 修复阈值比较符以匹配AITER内部逻辑 (correctness): 将比较符改为<=,确保在64 MB阈值处正确激活融合路径,避免漏激活。

风险与影响

  • 风险:技术风险包括:1. 回归风险:移除白名单后,如果AITER C++层对某些hidden_dim支持不足,可能静默回退到2-stage,但PR提到这是安全的,因为AITER会覆盖unsupported dim;2. 性能风险:字节阈值保留,但off-by-one修复可能使更多小批次激活融合路径,需确保AITER内核性能稳定;3. 兼容性风险:依赖外部AITER PR(#2586和#2453),若未正确集成可能导致数值问题,但测试中添加了残差精度检查以缓解。具体风险点位于parallel_state.py的调度逻辑和communicator.py的阈值比较。
  • 影响:对用户影响:AMD平台用户在使用--enable-aiter-allreduce-fusion时,将更准确地激活融合路径,并支持更多hidden_dim模型而无需手动配置,提升体验和性能。对系统影响:简化了调度逻辑,减少了代码维护负担,使allreduce融合更健壮和自适应。对团队影响:促进了依赖下层组件决策的设计模式,提高了代码可维护性,并通过测试增强确保质量。
  • 风险标记:核心路径变更, 依赖外部组件, 移除白名单可能引入兼容性风险

关联脉络

  • PR #21947 [AMD] Add 2880 to hidden_dim allowlist for fused allreduce: 该PR将2880添加到hidden_dim白名单,但被当前PR取代,因为当前PR移除了整个白名单,使所有AITER支持的hidden_dim自动工作。

参与讨论