Prhub

#23482 [Diffusion][NPU]Add attention backends for diffusion models for Ascend NPU

原始 PR 作者 Napkin-AI 合并时间 2026-05-19 17:46 文件变更 12 提交数 12 评论 42 代码增减 +1157 / -19

执行摘要

NPU 扩散模型新增三种注意力后端

PR 描述指出 'Only SDPA and FIA backends available for Ascend NPU',新增 Laser、Block Sparse 和 Rain Fusion Attention 可提供更灵活的稀疏注意力方案,提高扩散模型的推理效率。

值得精读,尤其是如何在现有注意力抽象框架中新增后端,以及平台选择逻辑的 try-except 降级处理。对于 NPU 相关开发者,Rain Fusion 和 Block Sparse 的实现细节需关注 review 中的性能优化建议。

讨论亮点

讨论集中在代码质量、性能优化和依赖影响上。ping1jing2 担心 import attentions 会破坏 GPU CI,Napkin-AI 回应此库是 sgl-kernel-npu 的一部分且当前测试未使用,不会造成影响。对于 _avgpool_get_mask_index 函数的性能,ping1jing2 建议使用 F.pad 替代 torch.zeros+cat、避免 torch.einsum,Napkin-AI 认为当前版本已较优但部分接受。BSND 布局支持问题被提出,Napkin-AI 确认内核有精度问题,暂时维持转置方案并添加 TODO。

实现拆解

  1. 新增注意力后端类:在 backends/ 下创建三个文件,每个实现 AttentionBackendAttentionImpl 及可选的 AttentionMetadata/AttentionMetadataBuilder。Laser Attention 封装 torch.ops.attentions.la,Block Sparse 封装 torch.ops.attentions.sparse_block_attention,Rain Fusion 组合 Laser 与自定义池化/索引。
  2. 扩展平台选择逻辑:在 npu.pyget_attn_backend_cls_str 中添加三个 elif 分支,通过 try-import 检查依赖,返回对应后端类路径,缺失时抛出清晰安装指引。
  3. 更新枚举与配置:在 interface.pyAttentionBackendEnum 中添加三个新枚举,并在 is_sparse 中标记。在 configs/models/adapter/base.pydits/base.py 的支持列表中加入新枚举。
  4. 集成到扩散流水线:在 denoising.py_build_attn_metadata 中为 Block Sparse 和 Rain Fusion 构建元数据(当前时间步、稀疏度、潜空间形状等),Laser 无元数据。
  5. 更新文档:在 docs_newdocs 目录下的注意力后端文档和兼容性矩阵中添加新后端说明,包括 CLI 参数示例和依赖安装方式。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/layers/attention/backends/rain_fusion_attn.py 注意力后端 added 9.18
python/sglang/multimodal_gen/runtime/layers/attention/backends/block_sparse_attn.py 注意力后端 added 9.08
python/sglang/multimodal_gen/runtime/layers/attention/backends/laser_attn.py 注意力后端 added 8.9
python/sglang/multimodal_gen/runtime/platforms/npu.py 平台层 modified 6.91
python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py 流水线 modified 6.61
python/sglang/multimodal_gen/runtime/platforms/interface.py 配置枚举 modified 5.19
python/sglang/multimodal_gen/configs/models/adapter/base.py 配置 modified 4.63
python/sglang/multimodal_gen/configs/models/dits/base.py 配置 modified 4.63
docs_new/docs/sglang-diffusion/attention_backends.mdx 文档 modified 4.25
docs_new/docs/sglang-diffusion/compatibility_matrix.mdx 文档 modified 3.85
docs/diffusion/compatibility_matrix.md 文档 modified 3.01
docs/diffusion/performance/attention_backends.md 文档 modified 2.8

关键符号

RainFusionAttentionBackend.get_supported_head_sizes RainFusionAttentionMetadataBuilder.build BlockSparseAttentionBackend.get_metadata_cls BlockSparseAttentionMetadataBuilder.build LaserAttentionBackend.get_enum LaserAttentionImpl._la_preprocess_input NPUPlatform.get_attn_backend_cls_str DenoisingStage._build_attn_metadata

关键源码片段

python/sglang/multimodal_gen/runtime/layers/attention/backends/rain_fusion_attn.py dependency-wiring

核心新增文件:实现 Rain Fusion Attention 后端,包含完整的 Backend、Metadata、MetadataBuilder 和 Impl 类,通过激光注意力与自定义池化 / 索引融合提高稀疏注意力性能。

import math
from dataclasses import dataclass
from typing import Any, List, Optionalimport attentions # noqa: F401
import torch
from einops import rearrangefrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
    AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.laser_attn import LaserAttentionBackend
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.utils.logging_utils import init_loggerlogger = init_logger(__name__)
​
​
class RainFusionAttentionBackend(AttentionBackend):
    """Rain Fusion 注意力后端,封装了融合池化与激光注意力的稀疏方案"""
    accept_output_buffer: bool = True
​
    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128]
​
    @staticmethod
    def get_enum() -> AttentionBackendEnum:
        return AttentionBackendEnum.RAIN_FUSION_ATTN
​
    @staticmethod
    def get_impl_cls() -> type["RainFusionAttentionImpl"]:
        return RainFusionAttentionImpl
​
    @staticmethod
    def get_metadata_cls() -> type["RainFusionAttentionMetadata"]:
        return RainFusionAttentionMetadata
​
    @staticmethod
    def get_builder_cls() -> type["RainFusionAttentionMetadataBuilder"]:
        return RainFusionAttentionMetadataBuilder
​
​
@dataclass
class RainFusionAttentionMetadata(AttentionMetadata):
    current_timestep: int
    skip_first_steps: int
    sparsity: float
    latent_shape: list[int]
​
​
class RainFusionAttentionMetadataBuilder(AttentionMetadataBuilder):
    """构建 Rain Fusion 所需的元数据:验证稀疏度与步数,计算 latent 形状"""
    def __init__(self) -> None:
        pass
​
    def prepare(self) -> None:
        pass
​
    def build(self, current_timestep: int, skip_first_steps: int, sparsity: float,
              raw_latent_shape: list[int], patch_size: tuple[int, int, int],
              **kwargs: dict[str, Any]) -> RainFusionAttentionMetadata:
        if not (skip_first_steps >= 0 and 0.0 <= sparsity < 1.0):
            raise ValueError(
                f"Sparsity 应在 [0, 1) 范围,skip_first_steps 应非负,"
                f"当前:sparsity={sparsity}, skip_first_steps={skip_first_steps}"
            )
        if sparsity == 0.0:
            logger.warning("稀疏度为 0,建议使用 Laser Attention 或增加稀疏度以提升性能。")
        # 将 raw_latent_shape 按 patch_size 分割得到 latent 形状
        latent_shape = raw_latent_shape[-3:]
        latent_shape = [latent_shape[i] // patch_size[i] for i in range(3)]
        return RainFusionAttentionMetadata(
            current_timestep=current_timestep,
            skip_first_steps=skip_first_steps,
            sparsity=sparsity,
            latent_shape=latent_shape,
        )
python/sglang/multimodal_gen/runtime/layers/attention/backends/block_sparse_attn.py dependency-wiring

核心新增文件:实现 Block Sparse Attention 后端,支持块级稀疏注意力,通过元数据构建器将时间步、稀疏度与潜空间块对齐。

from dataclasses import dataclass
from typing import Anyimport attentions # noqa: F401
import torchfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
    AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.laser_attn import LaserAttentionBackend
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.utils.logging_utils import init_loggerlogger = init_logger(__name__)
BSA_BLOCK_SIZE = 128 # 块稀疏的块大小
​
​
class BlockSparseAttentionBackend(AttentionBackend):
    """Block Sparse Attention 后端,通过块级掩码实现稀疏注意力"""
    accept_output_buffer: bool = True
​
    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128]
​
    @staticmethod
    def get_enum() -> AttentionBackendEnum:
        return AttentionBackendEnum.BLOCK_SPARSE_ATTN
​
    @staticmethod
    def get_impl_cls() -> type["BlockSparseAttentionImpl"]:
        return BlockSparseAttentionImpl
​
    @staticmethod
    def get_metadata_cls() -> type["BlockSparseAttentionMetadata"]:
        return BlockSparseAttentionMetadata
​
    @staticmethod
    def get_builder_cls() -> type["BlockSparseAttentionMetadataBuilder"]:
        return BlockSparseAttentionMetadataBuilder
​
​
@dataclass
class BlockSparseAttentionMetadata(AttentionMetadata):
    current_timestep: int
    skip_first_steps: int
    sparsity: float
    block_frame_stride: int
​
​
class BlockSparseAttentionMetadataBuilder(AttentionMetadataBuilder):
    """构建元数据:校验参数,计算 block_frame_stride 以保持首帧非稀疏"""
    def __init__(self) -> None:
        pass
​
    def prepare(self) -> None:
        pass
​
    def build(self, current_timestep: int, skip_first_steps: int, sparsity: float,
              raw_latent_shape: list[int], patch_size: tuple[int, int, int],
              **kwargs: dict[str, Any]) -> BlockSparseAttentionMetadata:
        if not (skip_first_steps >= 0 and 0.0 <= sparsity < 1.0):
            raise ValueError(
                f"Sparsity 应在 [0,1) 范围, skip_first_steps 应非负,"
                f"当前:sparsity={sparsity}, skip_first_steps={skip_first_steps}"
            )
        if sparsity == 0.0:
            logger.warning(
                "稀疏度为 0 将不丢弃任何 token,"
                "建议使用 Laser Attention 或增加稀疏度。"
            )
        # 从 raw_latent_shape 提取高宽,除以 patch_size 得到 block 数
        if len(raw_latent_shape) >= 5:
            latent_height, latent_width = raw_latent_shape[3:5]
        else:
            latent_height, latent_width = raw_latent_shape[-2:]
        latent_height //= patch_size[1]
        latent_width //= patch_size[2]
        frame_stride = latent_height * latent_width
        # 将 frame_stride 按 BSA_BLOCK_SIZE 对齐
        block_frame_stride = (frame_stride + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
        return BlockSparseAttentionMetadata(
            current_timestep=current_timestep,
            skip_first_steps=skip_first_steps,
            sparsity=sparsity,
            block_frame_stride=block_frame_stride,
        )
python/sglang/multimodal_gen/runtime/layers/attention/backends/laser_attn.py dependency-wiring

核心新增文件:实现 Laser Attention 后端,封装 sgl-kernel-npu 的 la 内核,作为块稀疏和雨融合的基础依赖。

import torchfrom sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
    AttentionBackend, AttentionImpl, AttentionMetadata,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.sdpa import SDPABackend
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger# 导入 torch.ops.attentions,该模块由 sgl_kernel_npu 提供
try:
    import attentions # noqa: F401
except ImportError as e:
    raise ImportError(
        "需要安装 sgl_kernel_npu 以提供 attentions 包"
    ) from elogger = init_logger(__name__)
​
​
class LaserAttentionBackend(AttentionBackend):
    """Laser Attention 后端封装,无元数据构建需求"""
    accept_output_buffer: bool = True
​
    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128]
​
    @staticmethod
    def get_enum() -> AttentionBackendEnum:
        return AttentionBackendEnum.LASER_ATTN
​
    @staticmethod
    def get_impl_cls() -> type["LaserAttentionImpl"]:
        return LaserAttentionImpl
​
​
class LaserAttentionImpl(AttentionImpl):
    """Laser Attention 实现,通过 torch.ops.attentions.la 调用 NPU 内核"""
    def __init__(self, num_heads: int, head_size: int, causal: bool, softmax_scale: float,
                 num_kv_heads: int | None = None, prefix: str = "", **extra_impl_args) -> None:
        self.softmax_scale = softmax_scale
        # 预处理后张量布局预期为 BNSD
        self.seqlen_base = 256
        self.seqlen_index = 2
        self.dim_index = 3
        self.dim_base = 128
        self.max_token = 2**31 - 1
        self.seq_len_pad_base = 256
        # 激光注意力算子在短序列上存在问题,低于 2048 时使用 SDPA 回退
        self.min_seqlen = 2048
        self.sdpa_impl = SDPABackend.get_impl_cls()(
            num_heads, head_size, causal, softmax_scale,
            num_kv_heads, prefix, **extra_impl_args,
        )

评论区精华

导入 attentions 库对 GPU CI 的影响 question

ping1jing2 指出 `import attentions` 可能破坏 GPU CI,Napkin-AI 回应此库是 sgl-kernel-npu 的一部分,且当前测试未使用这些后端,不会影响 CI。

结论:确认不影响现有 CI,无需修改 pyproject.toml。 · 已解决

使用 F.pad 优化 torch.zeros+cat 性能

ping1jing2 建议用 `torch.nn.functional.pad` 替换 `torch.zeros+torch.cat` 以提升张量填充性能。Napkin-AI 未在回复中明确拒绝,后续代码中已使用 `_pad` 方法。

结论:已采纳,使用 pad 实现。 · 已解决

BSND 布局支持讨论 设计

ping1jing2 建议支持 BSND 布局避免转置开销,Napkin-AI 指出内核有精度问题,无法支持 BSND,暂时维持 BNSD 布局并转置。

结论:暂不支持 BSND,添加 TODO 标注。 · 已解决

Rain Fusion 实现中长上下文的性能优化 性能

ping1jing2 强烈建议审阅整个 Rain Fusion 文件,特别是长上下文场景,指出 `_avgpool` 和 `_get_mask_index` 存在内存和延迟优化空间。Napkin-AI 认为当前版本已较优,未来可考虑 Triton 或设备语言优化。

结论:未采纳具体优化,保持当前实现。 · unresolved

缺少单元测试 测试

PR 未包含测试文件,Napkin-AI 在评论中表示计划在 sgl-kernel-npu 新版本发布后补充 CI 安装依赖和测试。

结论:待未来添加。 · unresolved

风险与影响

  • 缺少测试覆盖:三个新后端无对应单元测试,仅靠集成测试验证,回归风险较高。
  • 依赖自定义库:均依赖 sgl-kernel-npu 的 attentions 库,库接口变更或安装失败将导致后端不可用。
  • 性能风险:Rain Fusion 的 _avgpool_get_mask_index 被指存在内存/延迟优化空间,长序列场景下可能不达预期。
  • 兼容性限制:Laser Attention 不支持交叉注意力,无法用于所有模型。

用户影响:NPU 扩散模型用户现在可选用三种新注意力后端,通过 --attention-backend--attention-backend-config 灵活调整速度与精度。系统影响:代码量增加约 1150 行,新增三个后端模块,增加维护成本。团队影响:需跟进 sgl-kernel-npu 发布并规划集成测试。

缺少测试覆盖 依赖 sgl-kernel-npu 内核 长上下文性能风险 跨注意力兼容性限制

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论