Prhub

#43565 [XPU] support MTP of gdn attention

原始 PR 作者 mayuyuace 合并时间 2026-05-29 17:10 文件变更 1 提交数 2 评论 4 代码增减 +40 / -8

执行摘要

XPU GDN 注意力支持 MTP 推测解码

XPU 平台之前不支持 GDN 注意力的推测解码,代码中有 assert attn_metadata.spec_sequence_masks is None 的硬断言。本 PR 旨在解除这一限制,使 XPU 能够与 CUDA 平齐,支持 MTP 推测解码,从而提升推理吞吐量。PR body 附带了在 Qwen3-Next-80B-A3B-Instruct 模型上的 lm_eval 结果和吞吐量对比数据,表明支持 MTP 后性能有显著提升。

该 PR 功能明确、改动集中,值得相关开发人员精读。关注的要点:

  • 如何将推测解码元数据从 attention metadata 提取并传递给底层内核。
  • 使用局部变量统一管理内核参数的模式,便于后续扩展。
  • 与 CUDA 端同类实现(参考 qwen_gdn_linear_attn.py)的对比可加深对跨平台一致性设计的理解。
  • 自动化 review 中提出的代码质量建议虽未完全采纳,但可作为后续代码清洁的切入点。
讨论亮点

gemini-code-assist[bot] 提出的自动化 review 主要关注代码质量和一致性:

  • 类型转换逻辑简化:建议不要依赖 spec_sequence_masks 来决定是否转换索引类型,而应基于张量本身的性质检查,并建议也将 num_accepted_tokens 转为 int32 以确保一致性。
  • 连续性保证:建议对 spec_state_indices_tensor 也调用 .contiguous(),与 non_spec_state_indices_tensor 的处理保持一致。
  • 移除多余的 type: ignore 注释:指出局部变量上附带的 # type: ignore[attr-defined] 不再正确,建议清理。
  • yma11 提出简化建议:认为可以保留原内联方式,避免引入过多局部变量。
  • 作者回复:作者引用 CUDA 端代码,指出其采用了相同的局部变量模式,且认为不影响性能。
    所有讨论均已关闭,PR 获得 jikunshang 的批准。

实现拆解

  1. 移除禁止推测解码的断言:删除 # TODO: xpu does not support speculative decoding yet 注释及紧跟的 assert attn_metadata.spec_sequence_masks is None,这是开启功能的准入开关。
  2. 提取推测解码相关元数据为局部变量:将原来直接通过 attn_metadata.xxx 传入内核的多个属性,改为先提取到局部变量(如 num_actual_tokensnum_accepted_tokensnum_spec_decodesspec_query_start_locspec_token_indxspec_state_indices_tensorspec_sequence_masks 等),并补充了之前缺失的 spec_query_start_locspec_token_indxspec_state_indices_tensornum_accepted_tokens。新变量 non_spec_token_indx 替代了原来的 non_spec_state_indices_tensor 的单一使用。
  3. 处理推测序列掩码和类型转换:当 spec_sequence_masks 非空时,将 non_spec_token_indxspec_token_indx 转为 int32 类型;确保 non_spec_state_indices_tensor 连续,并新增对 spec_state_indices_tensor 的连续性保证(但未在代码中显式调用 .contiguous(),review 中有建议)。
  4. 更新内核参数传递:将原来直接传递 attn_metadata.xxx 的地方替换为局部变量,并新增 num_spec_decodesspec_query_start_locspec_token_indxspec_state_indices_tensornum_accepted_tokens 等参数传给 torch.ops._xpu_C.gdn_attention 内核,同时移除不再使用的 # type: ignore[attr-defined] 注释。
文件 模块 状态 重要度
vllm/_xpu_ops.py XPU 操作 modified 6.43

关键符号

_gdn_attention_core_xpu_impl

关键源码片段

vllm/_xpu_ops.py core-logic

唯一的变更文件,包含 GDN 注意力内核的 XPU 实现。通过移除断言、提取元数据、新增参数实现了对 MTP 推测解码的支持。

# vllm/_xpu_ops.py 中重构后的 _gdn_attention_core_xpu_impl
# 该函数是 XPU 上 GDN 注意力的 SYCL 内核封装,现支持 MTP 推测解码def _gdn_attention_core_xpu_impl(
    core_attn_out: torch.Tensor,
    z: torch.Tensor,
    projected_states_qkvz: torch.Tensor,
    projected_states_ba: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom op wrapping the XPU SYCL GDN kernel for torch.compile."""
    from vllm.forward_context import get_forward_context
    from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
​
    forward_context = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    attn_metadata_raw = forward_context.attn_metadata
​
    if attn_metadata_raw is None:
        return
​
    assert isinstance(attn_metadata_raw, dict)
    attn_metadata = attn_metadata_raw[self.prefix]
    assert isinstance(attn_metadata, GDNAttentionMetadata)
​
    # 以下提取所有推测解码相关的元数据到局部变量
    # 之前此处有断言阻止推测解码,现已移除
    num_actual_tokens = attn_metadata.num_actual_tokens
    num_accepted_tokens = attn_metadata.num_accepted_tokens
​
    num_prefills = attn_metadata.num_prefills
    num_decodes = attn_metadata.num_decodes
    num_spec_decodes = attn_metadata.num_spec_decodes # 新增
​
    has_initial_state = attn_metadata.has_initial_state
​
    non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
    non_spec_token_indx = attn_metadata.non_spec_token_indx
    non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
    non_spec_state_indices_tensor = (
        non_spec_state_indices_tensor.contiguous()
        if non_spec_state_indices_tensor is not None
        else None
    )
​
    spec_query_start_loc = attn_metadata.spec_query_start_loc # 新增
    spec_token_indx = attn_metadata.spec_token_indx # 新增
    spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # 新增
​
    spec_sequence_masks = attn_metadata.spec_sequence_masks
    if spec_sequence_masks is not None:
        # 类型转换确保与内核期望的 int32 一致
        if non_spec_token_indx is not None:
            non_spec_token_indx = non_spec_token_indx.to(torch.int32)
        if spec_token_indx is not None:
            spec_token_indx = spec_token_indx.to(torch.int32)
​
    conv_weights = self.conv1d.weight.view(
        self.conv1d.weight.size(0), self.conv1d.weight.size(2)
    )
​
    torch.ops._xpu_C.gdn_attention(
        core_attn_out, z, projected_states_qkvz, projected_states_ba,
        self.num_k_heads, self.num_v_heads, self.head_k_dim, self.head_v_dim,
        conv_state=self.kv_cache[0],
        ssm_state=self.kv_cache[1],
        conv_weights=conv_weights,
        conv_bias=self.conv1d.bias,
        activation=self.activation,
        A_log=self.A_log,
        dt_bias=self.dt_bias,
        num_prefills=num_prefills,
        num_decodes=num_decodes,
        num_spec_decodes=num_spec_decodes, # 新增参数
        has_initial_state=has_initial_state,
        non_spec_query_start_loc=non_spec_query_start_loc,
        non_spec_token_indx=non_spec_token_indx, # 新增
        non_spec_state_indices_tensor=non_spec_state_indices_tensor,
        spec_query_start_loc=spec_query_start_loc, # 新增
        spec_token_indx=spec_token_indx, # 新增
        spec_state_indices_tensor=spec_state_indices_tensor, # 新增
        num_accepted_tokens=num_accepted_tokens, # 新增
        num_actual_tokens=num_actual_tokens,
        tp_size=self.tp_size,
        reorder_input=not self.gqa_interleaved_layout,
    )

评论区精华

简化类型转换逻辑 设计

`gemini-code-assist[bot]` 建议不要依赖 `spec_sequence_masks` 决定是否转换索引类型,而应基于张量自身检查,并建议也将 `num_accepted_tokens` 转为 int32。

结论:作者未采纳,当前实现保留依赖 `spec_sequence_masks` 的条件转换,保持与 CUDA 端一致。 · 已解决

确保 spec_state_indices_tensor 连续性 正确性

`gemini-code-assist[bot]` 建议对 `spec_state_indices_tensor` 也调用 `.contiguous()`,与 `non_spec_state_indices_tensor` 的处理一致。

结论:代码中未添加,可能作者认为内核内部处理或输入保证连续。 · 已解决

移除多余的 type: ignore 注释 style

`gemini-code-assist[bot]` 指出局部变量上的 `# type: ignore[attr-defined]` 不再正确,建议清理。

结论:代码中已移除这些注释,但部分可能残留。 · 已解决

避免不必要的局部变量 设计

`yma11` 建议保持内联方式,避免引入过多局部变量。

结论:作者回复引用 CUDA 端相同模式,认为不会影响性能,因此保留。 · 已解决

风险与影响

  1. 回归风险:低:仅修改 _gdn_attention_core_xpu_impl 函数,移除断言并扩展参数,逻辑上兼容非推测解码场景(spec_sequence_masks 为 None 时,相关局部变量也会为 None,内核需能处理)。但因缺少直接针对此变更的单元测试,无法保证在边界条件下(如混合 prefill/decode 与推测解码)无问题。
  2. 性能风险:低:引入局部变量赋值,开销极小,作者也认为不影响性能。
  3. 集成风险:依赖 vllm-xpu-kernels 仓库的对应内核更新(PR #368),若内核未同步更新则可能出现接口不匹配。PR body 中明确提及需要内核更新。

影响范围:仅 XPU 平台上的 GDN 注意力路径,涉及 MTP(Multi-Token Prediction)推测解码功能。
影响程度:中等。对使用 XPU + Qwen 类模型并启用推测解码的用户,可显著提升吞吐量(PR 中给出 B60 上的对比数据)。对其他模型或非推测解码场景无影响。
团队影响:需要 XPU 内核团队配合更新底层的 gdn_attention 内核以支持新增参数。

缺少测试覆盖 依赖外部内核更新

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论