执行摘要
- 一句话:修复AscendFA在序列并行时头数参数不正确
- 推荐动作:建议合并。这是一个定位准确、改动量小的bug修复,修复了NPU AscendFA后端在序列并行下的关键崩溃。review建议已被采纳,代码清晰。后续可考虑添加单元测试覆盖sp-degree>1的场景。
功能与动机
修复AscendFlashAttention后端在使用Ulysses序列并行(ulysses-degree > 1)时的崩溃。用户给出的复现命令中sp-degree=4(tp-size=2,num-gpus=8),运行时出现NPU函数调用失败错误,原因是query的shape[1,5,25200,128]与预期[1,20,25200,128]不匹配——即传入的num_heads参数是用全局头数(20)而非切分后的局部头数(20/4=5)。
实现拆解
-
删除__init__中的头数缓存(ascend_fa.py第70-71行):移除self.num_heads = num_heads和self.num_kv_heads = num_kv_heads or num_heads,因为这些值在调用forward时可能因序列并行而不再有效。
-
在forward中动态获取头数(ascend_fa.py第80行):在forward方法开始时,通过query.shape[2]和key.shape[2]获取当前输入张量的实际头数。由于在transpose(1,2)之前,第2维(索引2)正好对应头数维度(BNSD布局中N的位置),此时query和key已经由上层调度根据sp-degree切分,因此取到的就是局部头数。
-
更新NPU调用参数(ascend_fa.py第94-95行):将原本的num_heads=self.num_heads和num_key_value_heads=self.num_kv_heads改为使用步骤2中动态获取的num_heads和num_key_value_heads变量,确保传入NPU API的实际头数与张量匹配。
-
配套调整:无测试文件变更,仅修改核心逻辑,共+3/-4行。
关键文件:
python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py(模块 注意力后端;类别 source;类型 core-logic;符号 AscendFAImpl.init, AscendFAImpl.forward): 核心变更文件,修复AscendFA后端在序列并行下头数参数错误导致的崩溃,从缓存值改为动态获取。
关键符号:AscendFAImpl.init, AscendFAImpl.forward
关键源码片段
python/sglang/multimodal_gen/runtime/layers/attention/backends/ascend_fa.py
核心变更文件,修复AscendFA后端在序列并行下头数参数错误导致的崩溃,从缓存值改为动态获取。
# 仅展示修改后的关键部分,省略未修改的 imports 和基类
class AscendFAImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
# 注意:不再缓存 num_heads/num_kv_heads,
# 因为在序列并行(ulysses)时,
# 构造时传入的全局头数与 forward 时实际的局部头数不一致。
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
return_softmax_lse: bool = False,
) -> torch.Tensor:
# 在 transpose 之前(BNSD 布局,第 2 维是头数),
# 从输入张量动态获取实际头数,确保序列并行时参数正确。
num_heads = query.shape[2]
num_key_value_heads = key.shape[2]
mask = None
if self.causal:
seq_len = query.shape[1]
mask = torch.triu(
torch.ones(seq_len, seq_len, device=query.device), diagonal=1
).bool()
# transpose to bs, heads, seq_len, head_dim
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
output, lse = torch.ops.npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
scale=self.softmax_scale,
input_layout="BNSD",
softmax_lse_flag=return_softmax_lse,
atten_mask=mask,
)
output = output.transpose(1, 2)
if return_softmax_lse:
return output, lse
return output
评论区精华
reviewer gemini-code-assist[bot] 指出:在forward中query.shape[1]同时用于获取seq_len和(在transpose前的版本中)head counts,两个不同含义使用同一索引可能造成维护混淆。虽然当前实现正确,但建议将head counts捕获到局部变量以提高可读性。
后续修改采纳了建议:最终实现中已改为在forward开始处(transpose前)通过query.shape[2]和key.shape[2]获取头数,与seq_len(query.shape[1])分开,解决了review中提出的混淆问题。
结论:review建议被采纳,最终提交中正确分离了seq_len和头数的获取。
- 使用query.shape[1]同时表示seq_len和head counts的混淆 (correctness): 最终实现中使用query.shape[2]和key.shape[2]获取头数,在transpose之前操作,避免了与seq_len的冲突,提高了可读性和健壮性。
风险与影响
- 风险:
- 回归风险低:变更仅影响AscendFA后端,且逻辑简单——从缓存值改为动态取值。在非序列并行场景下,query.shape[2]应与缓存的num_heads相同(因为Ulysses degree=1时不切分),因此行为不变。
- 性能影响:无,仅在forward中多了两次shape读取,开销可忽略。
- 兼容性:仅影响NPU + AscendFA后端,不影响其他后端。
- 测试覆盖缺失:PR未添加对应单元测试。考虑到这是针对特定sp-degree配置的bugfix,若后续重构该模块或变更forward参数顺序,可能再次引入类似问题。
- 影响:
- 用户影响:修复了NPU用户在使用Ulysses序列并行(如sp-degree=4,tp-size=2,num-gpus=8)时运行扩散模型(如Wan2.2-T2V)的崩溃问题,影响面局限在NPU + AscendFA + 序列并行组合场景。
- 系统影响:无,代码改动量小(+3/-4行),不涉及其他模块。
- 团队影响:对于维护NPU后端的开发人员,这是一个清晰的bug修复示例,说明序列并行下构造缓存与运行时实际参数可能不一致的风险。
- 风险标记:缺少测试覆盖, 核心路径变更
关联脉络
- PR #23198 [diffusion] Fix --warmup-resolutions hang with --enable-cfg-parallel: 同为diffusion模块在NPU上的bug修复,涉及ascend_fa.py同一目录的后端,且都是与并行策略相关的修复。
参与讨论