Prhub

#24654 [Bugfix] [DSA] [Hisparse] Broadcast TP Rank 0 Topk Indexes to other TPs

原始 PR 作者 xz-keg 合并时间 2026-05-29 12:14 文件变更 3 提交数 23 评论 13 代码增减 +64 / -23

执行摘要

广播 TP rank 0 的 top-k 索引,修复 DSA/HiSparse 跨 TP 不一致

PR body 指出当前 topk.cu 实现不稳定(相同输入输出可能不同),导致不同 TP rank 选择的 top-k 不一致。对于 DSA 模型,这种不一致虽然对最终精度影响不大(因为备选 top-k 权重低),但对于 HiSparse 实现和策略开发是致命的,因为索引选择还影响不同 GPU 上的 buffer。关联测试 PR #24819 验证了该问题。评论中 DarkSharpness 确认 topk 不稳定,尤其在 RL 中可能致命。

该 PR 是解决 DSA top-k 不确定性的重要临时补丁,对使用 HiSparse 或进行确定性推理的用户值得关注。设计上采用环境变量开关、兼容 CUDA 图捕获的做法值得学习。建议阅读 dsa_indexer.py 中的广播实现和性能分析部分。

讨论亮点
  • xiezhq-hermann 提问:为什么 HiSparse 受影响更大?xz-keg 回答:因为索引选择影响 GPU 上的 buffer,不仅影响注意力权重。
  • DarkSharpness 确认:当前 topk 算法不稳定,尤其在 RL 中可能致命。广播是临时方案,未来需要实现更稳定的 topk 变体。
  • 性能权衡讨论:xz-keg 提出 partial topk + broadcast 的可能性,但 DarkSharpness 指出 page allocation 非确定性和通信开销仍是瓶颈。最终决定采用从 rank 0 广播的简单方案。

实现拆解

  1. 新增环境变量:在 python/sglang/srt/environ.py 中添加 SGLANG_DSA_TOPK_BROADCAST,类型 EnvBool,默认 False
  2. 实现广播逻辑:在 dsa_indexer.py 中新增三个函数:
    • _broadcast_indexer_topk_from_rank0_impl:实际的广播实现,使用 get_attn_tp_group() 获取通信组,对 CUDA 图捕获路径使用 PyNCCL,普通路径使用 group.broadcast
    • broadcast_indexer_topk_from_rank0_:用 @register_custom_op@register_split_op 装饰的算子,供 CUDA 图分段捕获时调用。
    • _broadcast_indexer_topk_from_rank0:对外包装函数,检查环境变量并选择路径。
  3. 修改 forward_cuda:在 skip_logits_computation 路径和空结果路径中,在 maybe_capture_indexer_topk 之前插入 _broadcast_indexer_topk_from_rank0 同步 top-k 结果。注意同步必须在最终输出上进行,避免内部 chunked topk_transform 导致集合操作分歧。
  4. 文档更新:在 docs_new/docs/references/environment_variables.mdx 中添加环境变量说明。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/dsa/dsa_indexer.py 注意力索引 modified 8.04
python/sglang/srt/environ.py 环境配置 modified 4.35
docs_new/docs/references/environment_variables.mdx 文档 modified 2.38

关键符号

broadcast_indexer_topk_from_rank0_ _broadcast_indexer_topk_from_rank0_impl _broadcast_indexer_topk_from_rank0

关键源码片段

python/sglang/srt/layers/attention/dsa/dsa_indexer.py core-logic

主要实现文件,新增广播 topk 的注册算子和调用逻辑,修改 forward_cuda 集成广播。

# 文件 : python/sglang/srt/layers/attention/dsa/dsa_indexer.py# 新增 import
from sglang.srt.distributed import (
    get_attn_context_model_parallel_rank,
    get_attn_context_model_parallel_world_size,
    get_attn_tp_group, # 新增
)# 注册自定义算子,用于 CUDA 图捕获路径
@register_custom_op(mutates_args=["topk_indices"])
@register_split_op()
def broadcast_indexer_topk_from_rank0_(topk_indices: torch.Tensor) -> None:
    """CUDA 图捕获时调用的广播算子包装。"""
    _broadcast_indexer_topk_from_rank0_impl(topk_indices)# 实际的广播实现
def _broadcast_indexer_topk_from_rank0_impl(topk_indices: torch.Tensor) -> None:
    """将 topk_indices 从 TP rank 0 广播到其他 rank。"""
    group = get_attn_tp_group()
    if group.world_size == 1:
        return
​
    if topk_indices.device.type == "cuda" and torch.cuda.is_current_stream_capturing():
        # CUDA 图捕获中必须使用 PyNCCL 才能被记录
        if group.pynccl_comm is None:
            raise RuntimeError(
                "SGLANG_DSA_TOPK_BROADCAST requires PyNCCL during CUDA graph capture."
            )
        with group.pynccl_comm.change_state(enable=True):
            group.pynccl_comm.broadcast(topk_indices, src=0)
    else:
        # 正常执行使用 group.broadcast
        group.broadcast(topk_indices, src=0)# 对外调用的包装函数,检查环境变量并选择路径
def _broadcast_indexer_topk_from_rank0(
    topk_indices: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """同步仅最终 indexer 输出,避免内部 topk_transform 的 chunk 分歧。"""
    if topk_indices is None or not envs.SGLANG_DSA_TOPK_BROADCAST.get():
        return topk_indices
​
    if is_in_piecewise_cuda_graph():
        # 分段 CUDA 图中使用自定义算子
        broadcast_indexer_topk_from_rank0_(topk_indices)
    else:
        _broadcast_indexer_topk_from_rank0_impl(topk_indices)
    return topk_indices# 在 forward_cuda 的 skip_logits_computation 分支中调用
# 原代码 : return maybe_capture_indexer_topk(layer_id, self._forward_cuda_k_only(...))
# 修改后 :
if skip_logits_computation and (not self.dsa_enable_prefill_cp):
    topk_result = self._forward_cuda_k_only(
        x, positions, forward_batch, layer_id,
        act_quant, enable_dual_stream, metadata, return_indices,
    )
    # 同步排名 0 的 topk 索引(如果启用)
    topk_result = _broadcast_indexer_topk_from_rank0(topk_result)
    return maybe_capture_indexer_topk(layer_id, topk_result)

评论区精华

为什么 HiSparse 受影响更大 question

xiezhq-hermann 提问为何 HiSparse 比普通 DSA 更受 topk 不一致影响。xz-keg 解释:因为索引选择还影响 GPU 上的 buffer,而不仅限于注意力权重。

结论:xiezhq-hermann 接受解释,问题解决。 · 已解决

广播是临时方案,未来需要稳定 topk 设计

DarkSharpness 确认 topk 不稳定,尤其在 RL 中可能致命,广播是临时工作区,未来需要实现稳定的 topk 变体。xz-keg 提到 RL 中已用 torch.topk 替代,但成本高。

结论:一致同意广播为临时方案,未来需实现确定性 topk 算法。 · 已解决

partial topk + broadcast 的可行性讨论 性能

xz-keg 提出每个 TP rank 计算部分 topk 再互相广播的改进思路。DarkSharpness 指出 page allocation 非确定性可能导致致命错误,且通信开销非平凡(topk kernel 可达 30% HBM 带宽,远超 NVLink 带宽)。

结论:暂时接受简单广播方案,未来再探索更优方案。 · 已解决

风险与影响

  1. 性能风险:开启后吞吐下降约 3%(来自 NCCL broadcast 开销),在长序列场景可能更明显。
  2. CUDA 图兼容风险:CUDA 图捕获时需要 PyNCCL,若 pynccl_commNone 会抛出 RuntimeError。
  3. 回归风险:默认关闭,不影响现有行为;但广播路径下如果 group 通信失败可能导致无声错误。
  4. 可维护性风险:引入了代码分支和临时方案,未来稳定 topk 实现后需清理。

用户:使用 HiSparse 且需要跨 TP 确定性的用户可受益;其他用户无影响(默认关闭)。系统:增加少量同步通信操作和代码分支。团队:需维护两种模式,未来计划替换为稳定 topk。

通信开销 CUDA 图兼容性 临时方案

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论