Prhub

#23646 [MUSA][Diffusion] Fix fa3 API on MT MUSA

原始 PR 作者 wenqf11 合并时间 2026-04-29 04:01 文件变更 2 提交数 1 评论 12 代码增减 +37 / -51

执行摘要

修复 MUSA 设备上 Flash Attention v3 的支持与 API 调用

在 MT MUSA 设备上运行 FlashAttention v3 时出现兼容性问题。原 _is_fa3_supported 仅在 CUDA 环境下有效,且未检测 MUSA 设备;flash_attn_varlen_func 的调用方式不够稳健,无法适应 MUSA 后端的参数要求。

值得精读,尤其对关注多硬件支持(Moore Threads)的开发者。展示如何将 CUDA 专有函数扩展至其他 GPU 架构,以及关键字参数调用的最佳实践。

讨论亮点
  • 简化条件逻辑gemini-code-assist[bot] 建议将嵌套条件改为扁平返回结构,作者采用了建议。
  • 跳过理由措辞yeahdongcn 指出 "MUSA >=mp31" 应改为 "MUSA mp31",作者回复 done,但最终 commit 仍包含 >=,可能为有意保留。

实现拆解

  1. 修改 _is_fa3_supportedpython/sglang/jit_kernel/flash_attention_v3.py):
    • 优先获取设备能力(通过 get_device_capability),若为 MUSA 则返回 major >= 3;否则检查 CUDA 版本和计算能力。
    • 移除了对 torch.cuda.get_device_capability 的直接依赖,使用通用工具函数。
  2. 调整 flash_attn_varlen_func 调用方式(同一文件):
    • 将间接的 _call_fa3_kernel 调用改为直接调用内核函数,并使用关键字参数传递所有参数,确保不同后端(CUDA/MUSA)均能正确匹配参数位置。
  3. 测试文件净化test_flash_attention_3.py):
    • 删除本地独立的 is_fa3_supported 函数,改为从源码模块导入 _is_fa3_supported
    • 更新两个 @pytest.mark.skipif 条件,关联到导入的函数,并修正跳过原因描述。
文件 模块 状态 重要度
python/sglang/jit_kernel/flash_attention_v3.py JIT 内核 modified 7.27
python/sglang/jit_kernel/tests/test_flash_attention_3.py JIT 内核测试 modified 5.35

关键符号

_is_fa3_supported flash_attn_varlen_func

关键源码片段

python/sglang/jit_kernel/flash_attention_v3.py core-logic

核心修改文件:设备支持逻辑调整和 API 调用优化

# 修改后的 _is_fa3_supported 函数(位于 flash_attention_v3.py)
@cache_once
def _is_fa3_supported(device=None) -> bool:
    # 此函数用于判断当前设备是否支持 Flash Attention v3
    # 之前仅支持 CUDA 且依赖 torch.cuda.get_device_capability
    # 现在通过通用工具函数同时支持 MUSA
    major, minor = get_device_capability() # 不与特定后端耦合
    if is_musa():
        # MUSA 设备:需要计算能力 >= 3(例如 S5000 对应 major=5)
        return major >= 3
    if torch.version.cuda is not None and torch.version.cuda >= "12.3":
        # CUDA 设备:保持原有条件(sm80/sm90)
        return major == 9 or major == 8
    return False

# flash_attn_varlen_func 调用方式调整(同一文件)
# 原调用:
# return _call_fa3_kernel(_load_fa3_kernels()["flash_attn_varlen_func"], q, k, v, ...)
# 改为显式关键字参数,避免位置依赖:
return _load_fa3_kernels()["flash_attn_varlen_func"](
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_k,
    seqused_q=seqused_q,
    seqused_k=seqused_k,
    page_table=page_table,
    softmax_scale=softmax_scale,
    causal=causal,
    qv=qv,
    q_descale=q_descale,
    k_descale=k_descale,
    v_descale=v_descale,
    window_size=window_size,
    attention_chunk=attention_chunk,
    softcap=softcap,
    num_splits=num_splits,
    pack_gqa=pack_gqa,
    sm_margin=sm_margin,
    return_softmax_lse=return_softmax_lse,
    sinks=sinks,
    out=out,
)

python/sglang/jit_kernel/tests/test_flash_attention_3.py test-coverage

测试文件跟随源文件调整,删除重复函数并导入统一版本

# 测试文件中的关键变更
from sglang.jit_kernel.flash_attention_v3 import _is_fa3_supported # 新增导入# 删除本地重复的 is_fa3_supported 函数(行 22-36)# 更新 skipif 装饰器(共两处)
@pytest.mark.skipif(
    not _is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on CUDA sm90, sm80 or MUSA >= mp31",
)
# 第二处类似

评论区精华

简化 _is_fa3_supported 条件逻辑 设计

gemini-code-assist[bot] 建议使用扁平结构替代嵌套 if,提供参考代码。

结论:作者最终采用了扁平 if-return 结构,与建议一致。 · 已解决

跳过理由措辞修正 style

yeahdongcn 指出 'MUSA >= mp31' 应改为 'MUSA mp31'。

结论:作者回复已处理,但最终 commit 仍保留 '>=',可能为有意保留或未及时更新。 · 已解决

风险与影响

  • 非 MUSA 设备:仅将 torch.cuda.get_device_capability 替换为通用函数 get_device_capability,回归风险低。
  • MUSA 设备:首次启用 fa3 支持,若设备计算能力低于 major>=3 则静默跳过,属于合理降级;但缺少 MUSA 上的精度对比测试,可能存在数值偏差风险。
  • 兼容性:参数改为关键字调用,对 CUDA 后端无影响,但需确保 MUSA 后端内核签名匹配。
  • 用户:MT MUSA 设备用户现可使用 FA3 加速注意力计算,预期提升 TTFT 性能。
  • 系统:无。
  • 开发者:测试复用源码逻辑,减少重复维护;硬件抽象层改进为其他设备扩展提供参考。
MUSA 支持首次启用 缺少精度对比测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论