Prhub

#40845 [BE][Torch 2.12] Remove workaround code for fixed cublas issue

原始 PR 作者 Lucaskabela 合并时间 2026-04-29 12:07 文件变更 1 提交数 5 评论 9 代码增减 +11 / -11

执行摘要

移除 B200 批次不变性 workaround,统一 SM100 与 SM90 路径

在 PyTorch 2.12 环境下,cuBLASLt on B200 已是 batch-invariant 的(参考 pytorch/pytorch#181248)。原 workaround 是 torch-2.9 时代引入的,用于规避 B200 在 bs=1 时选择非 batch-invariant GEMV 路径的问题。现在不再需要,故移除,并统一 SM90 与 SM100 的代码路径。

值得精读。PR 展示了如何在上游修复后干净地剥离临时 workaround,同时注意了交叉平台安全(is_cuda() 保护)。是学习 vLLM 如何处理 GPU 架构差异和 PyTorch 版本兼容性的好例子。

讨论亮点
  • yewentao256: 指出 get_max_shared_memory_bytes 只对 CUDA 有效,但当前修改可能让 CPU worker 也进入该路径,存在风险。
    • Lucaskabela: 回复已 gated 到 CUDA cap 下。
  • yewentao256: 建议简化 else 分支注释,去掉“previous update”的提法。
    • Lucaskabela: 采纳建议并更新注释。

实现拆解

  1. 移除 SM100 特判:将原来的 if is_device_capability_family(100) or is_device_capability_family(80): 改为只检查 capability 80。这样 SM100 不再进入 Triton 持久矩阵乘法覆盖的分支。
  2. 统一 cuBLAS workspace 路径:SM100 与 SM90 一样进入 else 分支,设置 CUBLAS_WORKSPACE_CONFIGCUBLASLT_WORKSPACE_SIZE 以禁用 split-k 引起的不确定性。
  3. 修复 get_max_shared_memory_bytes 调用范围:将原来只在 SM100/80 分支中查询共享内存大小的逻辑独立出来,并加上 is_cuda() 检查,避免在非 CUDA 平台上误调用。
  4. 更新注释:将注释从解释历史 workaround 改为说明当前分支的选择理由(Ampere 需要 Triton 覆盖,而 Hopper/Blackwell 只需禁用 split-k)。
  5. 测试验证:在 B200 + torch 2.12 + triton 3.7.0 上运行 tests/v1/determinism/test_batch_invariance.py 全部 9 个用例通过。
文件 模块 状态 重要度
vllm/model_executor/layers/batch_invariant.py 模型执行器 modified 6.56

关键符号

enable_batch_invariant_mode

关键源码片段

vllm/model_executor/layers/batch_invariant.py data-contract

核心修改文件,修改了 enable_batch_invariant_mode 函数的条件分支,移除了 SM100 特殊路径,并调整了共享内存查询的保护逻辑。

def enable_batch_invariant_mode():
    # ... 全局变量声明 ...
​
    if _batch_invariant_MODE:
        return
​
    _batch_invariant_MODE = True
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
​
    if current_platform.is_device_capability_family(80):
        # SM80 (Ampere) 不能依赖 cuBLASLt 的确定性,因此安装 Triton 持久矩阵乘法覆盖。
        _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
        _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
        _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
        _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
    else:
        # Hopper (SM90) 和 Blackwell (SM100):唯一的不确定性来源是 split-k,
        # 通过 cuBLAS workspace 配置禁用。
        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
        _original_cublaslt_workspace_size = os.environ.get("CUBLASLT_WORKSPACE_SIZE", None)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
​
    # Triton bmm 和持久矩阵乘法内核读取此值来确定 FP16 N-tile 大小。
    # 在所有 CUDA 平台上无条件设置,因为 bmm 被覆盖。
    if current_platform.is_cuda():
        _fp16_block_size_n = 256 if get_max_shared_memory_bytes() > 106496 else 128

评论区精华

get_max_shared_memory_bytes 调用范围 正确性

yewentao256 指出 `get_max_shared_memory_bytes` 只对 CUDA 有效,担心 CPU worker 误入此路径。

结论:PR 作者意识到问题,随后在代码中加入了 `if current_platform.is_cuda():` 保护。 · 已解决

else 分支注释简化 style

yewentao256 建议不要在新的注释中提及“previous update”。

结论:PR 作者采纳建议,简化了注释,只描述当前逻辑。 · 已解决

风险与影响

  • 上游依赖变更风险:本修改依赖 PyTorch 2.12 的 cuBLAS 修复。如果用户在低于 2.12 的环境运行,B200 可能丢失 batch-invariant 保证。但 PR 面向 torch 2.12 升级,合理假设目标环境已升级。
  • 回归风险:B200 从 Triton 路径切换到 cuBLASLt 路径,尽管本地测试通过,但 CI 尚未在曾经失败的 B200 runner 上验证。不过 Triton 路径的移除减少了一个容易出现分歧的分支,长期看降低了回归可能性。
  • 非 CUDA 平台风险get_max_shared_memory_bytes 调用已用 is_cuda() 保护,避免在 CPU 等平台上调用出错。
  • 配置覆盖遗漏:原先 SM100 分支不设置 cuBLAS workspace 环境变量,现在设置;如果原流程依赖了省略这些变量的行为,可能有影响。但理论上配置只影响 split-k 行为,与 batch-invariance 目标一致。
  • 用户影响:B200 用户不再经过 Triton 持久矩阵乘法内核,转而使用 cuBLASLt,性能与确定性应与 H100 一致。所有用户不再需要为 B200 单独设置。
  • 系统影响:统一了 SM90 和 SM100 在 batch-invariance 模式下的行为,减少了平台特有的代码路径,便于未来维护。
  • 团队影响:移除了一段带有历史背景的 workaround 代码,降低了阅读负担。后续若 PyTorch 有类似变更,也可参考此 PR 的处理方式。
上游依赖变更 B200 特定路径修改 平台条件收紧

关联 Issue

#166735 [cuBLAS] Force tensor-core-no-reduction algo in `cuBLASLt` for `n=1` cases
#181248 [vllm] [2.12 regression][B200] test_batch_invariance: nondeterministic outputs 3/5 trials with FLASH_ATTN (B200 only, H100 passes)

完整报告

参与讨论