Prhub

#43361 [8/n] Migrate merge_attn_states, mamba, sampler to torch stable ABI (continued)

原始 PR 作者 cleonard530 合并时间 2026-05-28 00:35 文件变更 13 提交数 10 评论 18 代码增减 +432 / -403

执行摘要

迁移 attention/mamba/sampler 内核到 torch stable ABI

延续之前 PR #38841 的工作,目标是将 vLLM 中所有 CUDA 扩展迁移到 torch stable ABI,从而减少 _C.abi3.so 中的不稳定符号数量,提升扩展的稳定性和可维护性。PR body 中提供的 torch-abi-audit 对比明确显示了改进。此外,由于主分支已有较大变动(如 topk 改用 persistent_topk),提交需要手动重写相应部分。

建议精读 csrc/libtorch_stable/torch_bindings.cppcsrc/libtorch_stable/ops.h,了解稳定 ABI 的注册和声明模式。对于需要迁移自定义内核的开发者,本 PR 提供了清晰的参考模板。同时关注常量正确性讨论,这在跨 ABI 时尤其重要。

讨论亮点

审阅中 janeyx99 关注了以下要点:

  • 常量正确性:在 csrc/persistent_topk.cuhcsrc/libtorch_stable/ops.h 中,lengths 等参数被改为 const,以确保与稳定 ABI schema 匹配,且不影响功能。janeyx99 表示同意这些更改。
  • 与原始 PR 的差异:由于主分支已采用 persistent_topk 而非 large_context_topk,cleonard530 手动重写了 topk.cu 的迁移,并额外修改了 csrc/persistent_topk.cuh 来适配。Janeyx99 确认这一差异合理。
  • DeviceGuard 的使用:在新迁移的文件中增加了 DeviceGuard,janeyx99 建议保持模块化,但未要求移除。Harry-Chen 也给予了批准。

实现拆解

  1. 文件搬迁:将 csrc/attention/merge_attn_states.cucsrc/sampler.cucsrc/topk.cu 以及 csrc/mamba/selective_scan_fwd.cu 等内核实现文件移动到 csrc/libtorch_stable/ 下对应的子目录中,同时将 csrc/mamba/static_switch.hselective_scan.h 也一并迁移。
  2. 头文件与声明迁移:在 csrc/libtorch_stable/ops.h 中添加 merge_attn_statesapply_repetition_penalties_top_k_per_row_prefill/decodepersistent_topkselective_scan_fwd 的稳定 ABI 声明(使用 torch::stable::Tensor);并在 csrc/ops.h 中删除对应的旧声明,保持 CPU 构建所需的部分不变。
  3. 算子注册迁移:在 csrc/libtorch_stable/torch_bindings.cppSTABLE_TORCH_LIBRARY_FRAGMENT 中添加 merge_attn_states、sampler 系列和 selective_scan_fwdops.def()ops.impl() 注册;在 csrc/torch_bindings.cppTORCH_LIBRARY_EXPAND 中删除对应的旧注册,仅保留量化等尚未迁移的操作。
  4. 内核实现适配:在各 .cu 文件中替换类型和宏:torch::Tensortorch::stable::Tensorat::Halftorch::headeronly::HalfTORCH_CHECKSTD_TORCH_CHECKC10_CUDA_CHECKSTD_CUDA_CHECK;添加 #include "torch_utils.h" 等必要头文件;调整 device_guard 使用 torch::stable::accelerator::DeviceGuard;对 persistent_topklengths 等参数添加 const 限定以匹配稳定 ABI 要求。
  5. 构建与 lint 配套:在 pyproject.toml[tool.codespell] 扩展标识列表中添加 sharedMemPerBlockOptin 来避免 pre-commit 拼写检查报错;更新 CMakeLists.txt 使新位置的内核参与稳定扩展的编译。
文件 模块 状态 重要度
csrc/libtorch_stable/torch_bindings.cpp 稳定 ABI 绑定 modified 6.92
csrc/torch_bindings.cpp 旧绑定注册 modified 6.62
csrc/libtorch_stable/ops.h 操作声明 modified 6.58
csrc/ops.h 旧操作声明 modified 6.22
csrc/libtorch_stable/topk.cu TopK 内核 renamed 5.84
csrc/libtorch_stable/sampler.cu 采样器内核 renamed 5.76

关键符号

merge_attn_states apply_repetition_penalties_ top_k_per_row_prefill top_k_per_row_decode persistent_topk selective_scan_fwd

关键源码片段

csrc/libtorch_stable/torch_bindings.cpp core-logic

核心绑定注册文件,在此添加了 merge_attn_states、sampler、mamba 的稳定 ABI 定义和实现注册,是迁移的核心入口。

// csrc/libtorch_stable/torch_bindings.cpp
// 在 STABLE_TORCH_LIBRARY_FRAGMENT 中添加以下注册:// Merge attn states(合并注意力状态)
// 实现 https://www.arxiv.org/pdf/2501.01005 第 2.2 节
// 用于在 split-KV 情况下合并部分注意力结果
ops.def(
    "merge_attn_states("
    "    Tensor! output,"
    "    Tensor!? output_lse,"
    "    Tensor prefix_output,"
    "    Tensor prefix_lse,"
    "    Tensor suffix_output,"
    "    Tensor suffix_lse,"
    "    int!? prefill_tokens_with_context,"
    "    Tensor? output_scale=None) -> ()");// 在 CUDA impl 段中绑定实现
ops.impl("merge_attn_states", TORCH_BOX(&merge_attn_states));// Apply repetition penalties(应用重复惩罚)
ops.def(
    "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
    "Tensor output_mask, Tensor repetition_penalties) -> ()");
ops.impl("apply_repetition_penalties_", TORCH_BOX(&apply_repetition_penalties_));// Optimized top-k per row(每行 top-k 优化)
ops.def(
    "top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
    "Tensor! indices, int numRows, int stride0, "
    "int stride1, int topK) -> ()");
ops.impl("top_k_per_row_prefill", TORCH_BOX(&top_k_per_row_prefill));ops.def(
    "top_k_per_row_decode(Tensor logits, int next_n, "
    "Tensor seq_lens, Tensor! indices, "
    "int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", TORCH_BOX(&top_k_per_row_decode));ops.def(
    "persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
    "Tensor workspace, int k, int max_seq_len) -> ()");
ops.impl("persistent_topk", TORCH_BOX(&persistent_topk));// Mamba selective scan 前向内核
ops.def(
    "selective_scan_fwd(Tensor! u, Tensor! delta,"
    "Tensor! A, Tensor! B, Tensor! C,"
    "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
    "bool delta_softplus,"
    "Tensor? query_start_loc,"
    "Tensor? cache_indices," // ... 更多参数省略
    ") -> ()");
ops.impl("selective_scan_fwd", TORCH_BOX(&selective_scan_fwd));
csrc/torch_bindings.cpp core-logic

旧绑定注册文件,移除了已迁移的 merge_attn_states、sampler 和 mamba 注册,只保留尚未迁移的量化等操作。

// csrc/torch_bindings.cpp(删除的片段)
// 以下注册被整体移除,迁移至稳定库:/* 删除前:
ops.def(
    "merge_attn_states("
    " Tensor! output,"
    ...
    " Tensor? output_scale=None) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);ops.def("apply_repetition_penalties_(...)");
ops.impl("apply_repetition_penalties_", torch::kCUDA, &apply_repetition_penalties_);ops.def("top_k_per_row_prefill(...)");
ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);ops.def("top_k_per_row_decode(...)");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);ops.def("persistent_topk(...)");
ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);ops.def("selective_scan_fwd(...)");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
*/
csrc/libtorch_stable/ops.h core-logic

稳定 ABI 头文件,新增了迁移内核的函数声明,使用 torch::stable::Tensor 等类型。

// csrc/libtorch_stable/ops.h(新增声明)// Attention 内核(共享 CUDA/ROCm)
void merge_attn_states(
    torch::stable::Tensor& output,
    std::optional<torch::stable::Tensor> output_lse,
    const torch::stable::Tensor& prefix_output,
    const torch::stable::Tensor& prefix_lse,
    const torch::stable::Tensor& suffix_output,
    const torch::stable::Tensor& suffix_lse,
    const std::optional<int64_t> prefill_tokens_with_context,
    const std::optional<torch::stable::Tensor>& output_scale = std::nullopt);// Sampler 内核(共享 CUDA/ROCm)
void apply_repetition_penalties_(
    torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
    const torch::stable::Tensor& output_mask,
    const torch::stable::Tensor& repetition_penalties);void top_k_per_row_prefill(const torch::stable::Tensor& logits,
    const torch::stable::Tensor& rowStarts,
    const torch::stable::Tensor& rowEnds,
    torch::stable::Tensor& indices, int64_t numRows,
    int64_t stride0, int64_t stride1, int64_t topK);void top_k_per_row_decode(const torch::stable::Tensor& logits, int64_t next_n,
    const torch::stable::Tensor& seqLens,
    torch::stable::Tensor& indices, int64_t numRows,
    int64_t stride0, int64_t stride1, int64_t topK);void persistent_topk(const torch::stable::Tensor& logits,
    const torch::stable::Tensor& lengths,
    torch::stable::Tensor& output,
    torch::stable::Tensor& workspace, int64_t k,
    int64_t max_seq_len);void selective_scan_fwd(
    const torch::stable::Tensor& u, const torch::stable::Tensor& delta,
    const torch::stable::Tensor& A, const torch::stable::Tensor& B,
    const torch::stable::Tensor& C,
    const std::optional<torch::stable::Tensor>& D_,
    // ... 更多参数
);

评论区精华

常量正确性(const-correctness) 正确性

janeyx99 在 `csrc/libtorch_stable/ops.h` 和 `csrc/persistent_topk.cuh` 中,对 `lengths` 等参数被改为 `const` 提出疑问,认为应确保与 schema 匹配。cleonard530 解释这些参数是只读的,改为 const 是正确的。

结论:janeyx99 同意 const 更改是正确的,并确认没有问题。 · 已解决

与原始 PR #38841 的差异 设计

cleonard530 指出 `topk.cu` 与原始 PR 差异很大,因为 main 分支已使用 persistent_topk 而非 large_context_topk,因此他手动重写了迁移。Janeyx99 确认了这一差异的合理性。

结论:认可的差异,无需额外修改。 · 已解决

DeviceGuard 的添加 设计

cleonard530 在新迁移的文件中增加了 `DeviceGuard`,janeyx99 建议保持模块化,但未强制要求移除。

结论:保留 DeviceGuard,后续 PR 可能统一处理。 · partially resolved

风险与影响

主要风险包括:

  • 回归风险:内核逻辑本身未改动,但类型和宏替换可能引入符号错位或编译失败。由于测试套件已覆盖相关路径(如 test_merge_attn_states.pytest_mamba_ssm.pytest_top_k_per_row.py),风险可控。
  • 兼容性风险:稳定 ABI 要求在 torch 版本间保持一致,但本次迁移仅使用官方支持的稳定类型,不会引入额外兼容问题。
  • 合并冲突:该 PR 解决了许多与 main 分支的冲突,特别在 topk.cusampler.cu 等文件上,未来继续迁移时仍需注意。
  • 构建风险:如果 CMakeLists.txt 未正确包含新位置的内核,可能导致链接失败。但从构建结果看已成功。

对用户透明,无功能变化。对开发者和构建系统而言,_C_stable_libtorch.abi3.so 的稳定符号数量不变(78个),_C.abi3.so 中的不稳定符号从 99 减少到 98,表明迁移正在逐步释放不稳定内核。后续阶段将继续迁移剩余 98 个不稳定符号。

常量正确性调整 内核迁移回归风险 构建配置变更

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论