Prhub

#23572 [Diffusion][NPU][Bugfix] Ascend_fa crashes when sequence parallelism is used.

原始 PR 作者 Napkin-AI 合并时间 2026-04-24 00:21 文件变更 1 提交数 2 评论 2 代码增减 +3 / -4

执行摘要

修复 AscendFA 在序列并行时头数参数不正确

修复AscendFlashAttention后端在使用Ulysses序列并行(ulysses-degree > 1)时的崩溃。用户给出的复现命令中sp-degree=4(tp-size=2,num-gpus=8),运行时出现NPU函数调用失败错误,原因是query的shape[1,5,25200,128]与预期[1,20,25200,128]不匹配——即传入的num_heads参数是用全局头数(20)而非切分后的局部头数(20/4=5)。

建议合并。这是一个定位准确、改动量小的bug修复,修复了NPU AscendFA后端在序列并行下的关键崩溃。review建议已被采纳,代码清晰。后续可考虑添加单元测试覆盖sp-degree>1的场景。

讨论亮点

reviewer gemini-code-assist[bot] 指出:在forward中query.shape[1]同时用于获取seq_len和(在transpose前的版本中)head counts,两个不同含义使用同一索引可能造成维护混淆。虽然当前实现正确,但建议将head counts捕获到局部变量以提高可读性。

后续修改采纳了建议:最终实现中已改为在forward开始处(transpose前)通过query.shape[2]key.shape[2]获取头数,与seq_len(query.shape[1])分开,解决了review中提出的混淆问题。

结论:review建议被采纳,最终提交中正确分离了seq_len和头数的获取。

实现拆解

  1. 删除__init__中的头数缓存ascend_fa.py第70-71行):移除self.num_heads = num_headsself.num_kv_heads = num_kv_heads or num_heads,因为这些值在调用forward时可能因序列并行而不再有效。

  2. 在forward中动态获取头数ascend_fa.py第80行):在forward方法开始时,通过query.shape[2]key.shape[2]获取当前输入张量的实际头数。由于在transpose(1,2)之前,第2维(索引2)正好对应头数维度(BNSD布局中N的位置),此时query和key已经由上层调度根据sp-degree切分,因此取到的就是局部头数。

  3. 更新NPU调用参数ascend_fa.py第94-95行):将原本的num_heads=self.num_headsnum_key_value_heads=self.num_kv_heads改为使用步骤2中动态获取的num_headsnum_key_value_heads变量,确保传入NPU API的实际头数与张量匹配。

  4. 配套调整:无测试文件变更,仅修改核心逻辑,共+3/-4行。

文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py 注意力后端 modified 5.39

关键符号

AscendFAImpl.__init__ AscendFAImpl.forward

关键源码片段

python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py core-logic

核心变更文件,修复 AscendFA 后端在序列并行下头数参数错误导致的崩溃,从缓存值改为动态获取。

# 仅展示修改后的关键部分,省略未修改的 imports 和基类
class AscendFAImpl(AttentionImpl):
​
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        self.causal = causal
        self.softmax_scale = softmax_scale
        # 注意:不再缓存 num_heads/num_kv_heads,
        # 因为在序列并行(ulysses)时,
        # 构造时传入的全局头数与 forward 时实际的局部头数不一致。
​
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata,
        return_softmax_lse: bool = False,
    ) -> torch.Tensor:
        # 在 transpose 之前(BNSD 布局,第 2 维是头数),
        # 从输入张量动态获取实际头数,确保序列并行时参数正确。
        num_heads = query.shape[2]
        num_key_value_heads = key.shape[2]
        mask = None
        if self.causal:
            seq_len = query.shape[1]
            mask = torch.triu(
                torch.ones(seq_len, seq_len, device=query.device), diagonal=1
            ).bool()
        # transpose to bs, heads, seq_len, head_dim
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        output, lse = torch.ops.npu.npu_fused_infer_attention_score(
            query,
            key,
            value,
            num_heads=num_heads,
            num_key_value_heads=num_key_value_heads,
            scale=self.softmax_scale,
            input_layout="BNSD",
            softmax_lse_flag=return_softmax_lse,
            atten_mask=mask,
        )
        output = output.transpose(1, 2)
        if return_softmax_lse:
            return output, lse
        return output

评论区精华

使用 query.shape[1] 同时表示 seq_len 和 head counts 的混淆 正确性

reviewer 指出在初始实现中,query.shape[1] 在 forward 方法中既用于获取 seq_len(第 81 行),又用于获取头数(第 92-93 行),同一索引代表不同含义可能造成维护混淆。

结论:最终实现中使用 query.shape[2] 和 key.shape[2] 获取头数,在 transpose 之前操作,避免了与 seq_len 的冲突,提高了可读性和健壮性。 · 已解决

风险与影响

  1. 回归风险低:变更仅影响AscendFA后端,且逻辑简单——从缓存值改为动态取值。在非序列并行场景下,query.shape[2]应与缓存的num_heads相同(因为Ulysses degree=1时不切分),因此行为不变。
  2. 性能影响:无,仅在forward中多了两次shape读取,开销可忽略。
  3. 兼容性:仅影响NPU + AscendFA后端,不影响其他后端。
  4. 测试覆盖缺失:PR未添加对应单元测试。考虑到这是针对特定sp-degree配置的bugfix,若后续重构该模块或变更forward参数顺序,可能再次引入类似问题。
  1. 用户影响:修复了NPU用户在使用Ulysses序列并行(如sp-degree=4,tp-size=2,num-gpus=8)时运行扩散模型(如Wan2.2-T2V)的崩溃问题,影响面局限在NPU + AscendFA + 序列并行组合场景。
  2. 系统影响:无,代码改动量小(+3/-4行),不涉及其他模块。
  3. 团队影响:对于维护NPU后端的开发人员,这是一个清晰的bug修复示例,说明序列并行下构造缓存与运行时实际参数可能不一致的风险。
缺少测试覆盖 核心路径变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论