Prhub

#21383 [diffusion] [NPU] support ring attention on NPU with FA

原始 PR 作者 Makcum888e 合并时间 2026-03-31 01:10 文件变更 5 提交数 3 评论 13 代码增减 +188 / -2

执行摘要

新增 NPU Ascend Flash Attention 后端,支持 ring attention。

PR body中说明:'ring attention requires return_softmax_lse but sdpa backend cannot support this option.' 因此,需要为NPU实现一个支持该选项的新attention backend,使用torch.ops.npu.npu_fused_infer_attention_score,作为PR #20248的另一种方案。

建议技术管理者关注此PR,以了解如何为不同平台添加定制attention backend的架构模式。工程师可精读ascend_fa.py中的实现,学习如何集成硬件专用操作并遵循抽象基类设计,以及通过review讨论了解代码优化点。

讨论亮点

review评论中,gemini-code-assist[bot]指出了多个关键问题:AscendFABackend.get_metadata_cls方法应实现而非抛出NotImplementedError;文档措辞需更清晰以提升可读性;为保持代码一致性,建议将metadata类重命名为AscendFAMetadata;修复类型提示和移除未使用参数(如head_sizeattn_metadata)。所有建议均被采纳,作者Makcum888e回应“done”,审核者ping1jing2批准,表明讨论已解决。

实现拆解

使用Markdown按4个步骤拆解实现过程:

  1. 新增AscendFA后端实现:在python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py中创建AscendFAMetadataAscendFAMetadataBuilderAscendFABackendAscendFAImpl类。关键符号AscendFAImpl.forward调用torch.ops.npu.npu_fused_infer_attention_score以支持return_softmax_lse,为ring attention提供基础。
  2. 集成到NPU平台:修改python/sglang/multimodal_gen/runtime/platforms/npu.py中的get_attn_backend_cls_str方法,当selected_backendAttentionBackendEnum.FA时返回AscendFABackend类路径,实现后端选择逻辑。
  3. 更新测试配套:在python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json中添加qwen_image_t2i_2npu用例的性能基准数据,并在python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py中添加相应测试配置,验证ring attention功能。
  4. 文档同步:更新docs/diffusion/performance/attention_backends.md,修正NPU平台对FA的支持描述和兼容性表格,确保用户文档准确。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py 注意力后端 added 8.59
python/sglang/multimodal_gen/runtime/platforms/npu.py 平台集成 modified 5.47
python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json 性能测试 modified 4.89
python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py 测试配置 modified 4.33
docs/diffusion/performance/attention_backends.md 文档 modified 2.3

关键符号

AscendFAImpl.forward AscendFABackend.get_enum AscendFABackend.get_metadata_cls AscendFAMetadataBuilder.build NPUPlatform.get_attn_backend_cls_str

关键源码片段

python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py core-logic

新增 AscendFA 后端实现,是支持 ring attention 的核心逻辑文件,包含关键类和 forward 方法。

from dataclasses import dataclass
from typing import Any
import torch
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
    AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder
)
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum@dataclass
class AscendFAMetadata:
    pass # 元数据类,当前无需额外字段class AscendFAMetadataBuilder(AttentionMetadataBuilder):
    def __init__(self) -> None:
        pass
    def prepare(self) -> None:
        pass
    def build(self, **kwargs: Any) -> AttentionMetadata:
        return AscendFAMetadata() # 构建并返回元数据实例class AscendFABackend(AttentionBackend):
    @staticmethod
    def get_enum() -> AttentionBackendEnum:
        return AttentionBackendEnum.FA # 返回后端枚举标识
    @staticmethod
    def get_impl_cls() -> type["AscendFAImpl"]:
        return AscendFAImpl # 返回实现类
    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AscendFAMetadata # 实现元数据类返回,修复原 NotImplementedError
    @staticmethod
    def get_builder_cls() -> type["AttentionMetadataBuilder"]:
        return AscendFAMetadataBuilder # 返回构建器类class AscendFAImpl(AttentionImpl):
    def __init__(self, num_heads: int, head_size: int, causal: bool, softmax_scale: float,
                 num_kv_heads: int | None = None, prefix: str = "", **extra_impl_args) -> None:
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        # 注意:head_size 和 prefix 参数未使用,保留以保持接口兼容性
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_metadata: AttentionMetadata, return_softmax_lse: bool = False) -> torch.Tensor:
        mask = None
        if self.causal:
            seq_len = query.shape[1]
            mask = torch.triu(torch.ones(seq_len, seq_len, device=query.device), diagonal=1).bool() # 构建因果掩码
        query = query.transpose(1, 2) # 转置为 BSHD 布局
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        output, lse = torch.ops.npu.npu_fused_infer_attention_score(
            query, key, value, num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads, scale=self.softmax_scale,
            input_layout="BNSD", softmax_lse_flag=return_softmax_lse, atten_mask=mask
        ) # 调用 NPU 专用 fused attention 操作,支持返回 softmax LSE
        output = output.transpose(1, 2) # 转置回 BSHD 布局
        if return_softmax_lse:
            return output, lse # 返回输出和 LSE,支持 ring attention
        return output

评论区精华

修复 NotImplementedError in get_metadata_cls 正确性

gemini-code-assist[bot] 指出 AscendFABackend.get_metadata_cls 方法抛出 NotImplementedError 会导致运行时错误,应返回 FlashAttentionMetadata(后重命名为 AscendFAMetadata)。

结论:作者 Makcum888e 修复为返回 AscendFAMetadata,审核者 ping1jing2 认可。 · 已解决

改进文档措辞清晰度 documentation

gemini-code-assist[bot] 建议重写文档中 NPU 描述,使其更易读,例如从 'for ring attention uses FA otherwise uses PyTorch SDPA' 改为更清晰表述。

结论:文档更新采纳建议,提升可读性。 · 已解决

代码一致性和未使用参数优化 设计

gemini-code-assist[bot] 建议为保持一致性重命名 metadata 类,并修复类型提示和移除未使用参数(如 head_size、attn_metadata),以提升代码质量。

结论:作者实施重命名和参数清理,审核者批准。 · 已解决

风险与影响

技术风险具体包括:新backend依赖NPU特定操作torch.ops.npu.npu_fused_infer_attention_score,在其他平台不可用可能导致兼容性问题;AscendFAImpl.__init__中未使用的参数(如head_size)可能影响未来扩展性和代码清晰度;测试覆盖虽新增性能基准,但新backend的完整功能验证依赖于NPU硬件环境,可能存在环境特定问题。

对用户而言,现在可以在NPU上使用ring attention,可能提升diffusion模型的推理性能和效率。系统层面,扩展了attention backend生态系统,增强了多平台支持,但增加了NPU特定代码的维护负担。团队需要确保新代码与现有backend接口兼容,并可能影响后续NPU相关开发。

NPU 依赖风险 未使用参数 测试覆盖有限

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论