Prhub

#39466 [XPU] Enable torch.compile for XPU GDN attention

原始 PR 作者 yuwenzho 合并时间 2026-04-24 16:26 文件变更 3 提交数 14 评论 23 代码增减 +81 / -41

执行摘要

将 XPU GDN kernel 包装为自定义 op 以支持 torch.compile

Dynamo无法跟踪直接调用torch.ops._xpu_C.gdn_attention的代码,导致torch.compile时丢失该kernel。通过注册为自定义op,使编译器能够识别并正确处理该算子。

值得关注自定义op注册模式,这是vllm中处理torch.compile兼容性的标准做法。建议阅读vllm/_xpu_ops.py中的注册流程和forward_xpu的简化逻辑,可对比原先的内联版本理解抽象层次。

讨论亮点

核心讨论:为什么需要自定义op

jikunshang: "what's the relationship with torch.compile?"
作者: 之前的_xpu_C.gdn_attention未注册为vllm自定义op,编译器无法处理。

导入顺序问题

jikunshang: 建议在函数内部导入GDNAttentionMetadata以避免循环依赖。
作者采纳,将导入移至_gdn_attention_core_xpu_impl函数内部。

z tensor初始化风险

Copilot: 当attn_metadata为None时,op直接返回,z保持torch.empty,后续norm使用未初始化数据产生不确定输出。建议在返回前z.zero_()
未在最终代码中采纳(原逻辑已有相似问题),但可能在实际使用中因profiling路径短而未被触发。

实现拆解

  1. 注册自定义op:在vllm/_xpu_ops.py中新增_gdn_attention_core_xpu_impl_gdn_attention_core_xpu_fake函数,通过direct_register_custom_op注册opgdn_attention_core_xpu。实现函数从forward context获取层对象和注意力元数据,调用底层SYCL kernel;fake函数直接返回,用于torch.compile图构建。

  2. 简化XPU forward路径:修改vllm/model_executor/layers/mamba/gdn_linear_attn.py中的forward_xpu方法,将原先内联的元数据提取和kernel调用逻辑替换为一行torch.ops.vllm.gdn_attention_core_xpu(...)调用,大幅减少代码量。

  3. 纳入编译系统:在vllm/config/compilation.py_attention_ops列表中添加"vllm::gdn_attention_core_xpu",使编译/图分割逻辑能正确识别该op。

  4. 配套调整(讨论中涉及但未包含在此PR):XPU上跳过cudagraph内存估计的修改在PR #39977中单独处理。

文件 模块 状态 重要度
vllm/_xpu_ops.py 算子注册 modified 7.78
vllm/model_executor/layers/mamba/gdn_linear_attn.py 注意力层 modified 6.93
vllm/config/compilation.py 编译配置 modified 4.35

关键符号

_gdn_attention_core_xpu_impl _gdn_attention_core_xpu_fake forward_xpu

关键源码片段

vllm/_xpu_ops.py core-logic

注册自定义 op 的核心文件,包含 op 实现和 fake 实现,决定了 torch.compile 如何跟踪 XPU GDN kernel。

# 自定义 op 实现:从 forward context 获取层和注意力元数据,调用 SYCL kernel
# 注意导入在函数内部以避免循环依赖
def _gdn_attention_core_xpu_impl(
    core_attn_out: torch.Tensor,
    z: torch.Tensor,
    projected_states_qkvz: torch.Tensor,
    projected_states_ba: torch.Tensor,
    layer_name: str,
) -> None:
    from vllm.forward_context import get_forward_context
    from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
​
    forward_context = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    attn_metadata_raw = forward_context.attn_metadata
​
    if attn_metadata_raw is None:
        return # profiling 时无元数据,跳过 kernel,z 保持 empty
​
    assert isinstance(attn_metadata_raw, dict)
    attn_metadata = attn_metadata_raw[self.prefix]
    assert isinstance(attn_metadata, GDNAttentionMetadata)
    assert attn_metadata.spec_sequence_masks is None # XPU 暂不支持推测解码
​
    conv_weights = self.conv1d.weight.view(
        self.conv1d.weight.size(0), self.conv1d.weight.size(2)
    )
​
    torch.ops._xpu_C.gdn_attention(
        core_attn_out, z, projected_states_qkvz, projected_states_ba,
        self.num_k_heads, self.num_v_heads, self.head_k_dim, self.head_v_dim,
        conv_state=self.kv_cache[0], ssm_state=self.kv_cache[1],
        conv_weights=conv_weights, conv_bias=self.conv1d.bias,
        activation=self.activation, A_log=self.A_log, dt_bias=self.dt_bias,
        num_prefills=attn_metadata.num_prefills,
        num_decodes=attn_metadata.num_decodes,
        has_initial_state=attn_metadata.has_initial_state,
        non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc,
        non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor,
        num_actual_tokens=attn_metadata.num_actual_tokens,
        tp_size=self.tp_size,
        reorder_input=not self.gqa_interleaved_layout,
    )
​
​
def _gdn_attention_core_xpu_fake(
    core_attn_out: torch.Tensor,
    z: torch.Tensor,
    projected_states_qkvz: torch.Tensor,
    projected_states_ba: torch.Tensor,
    layer_name: str,
) -> None:
    return # fake impl: no-op

评论区精华

为什么需要自定义 op 设计

jikunshang 询问与 torch.compile 的关系,作者解释:之前的 `_xpu_C.gdn_attention` 未注册为 vllm 自定义 op,编译器无法处理。

结论:确认需要通过注册自定义 op 来实现 torch.compile 的 traceability。 · 已解决

z tensor 未初始化风险 正确性

Copilot 指出当 attn_metadata 为 None 时,op 直接返回,z 保持 `torch.empty`,后续 norm 使用未初始化数据产生不确定输出。建议在返回前 `z.zero_()`。

结论:未采纳,认为原逻辑已有相似问题且 profile 阶段后续不会实际使用 z(可能因重新初始化)。风险仍存在。 · 未解决

导入顺序和循环依赖 设计

jikunshang 建议在函数内部导入 GDNAttentionMetadata 以避免循环依赖。

结论:作者采纳,将 GDNAttentionMetadata 导入移到函数内部。 · 已解决

风险与影响

  1. z tensor未初始化forward_xpuz初始化为torch.empty_like,若op在attn_metadata为None时直接返回(profile阶段),后续输出投影使用未初始化的z可能导致NaN或错误。原逻辑已有此问题,但可能因profile后重新初始化而逃逸。
  2. 导入顺序风险(已缓解):通过将GDNAttentionMetadata的导入移到函数内部,避免了模块级循环依赖。
  3. 平台专用风险:变更仅影响XPU平台,对其他平台无影响,但若XPU上op注册失败,forward_xpu会直接报错。
  4. 缺少测试覆盖:未看到针对新op的单元测试或集成测试。

启用torch.compile后,XPU上使用GDN attention的模型(如Qwen3.5等Mamba架构)可获得编译优化带来的性能提升。影响范围限于XPU平台,且仅影响注意力计算路径。代码量减少,可维护性提升。

缺少测试覆盖 z tensor 未初始化风险 XPU 平台专用

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论