Prhub

#42481 [Bugfix] Fix layerwise reload alias-buffer corruption

原始 PR 作者 rasdani 合并时间 2026-05-16 06:20 文件变更 3 提交数 13 评论 6 代码增减 +154 / -0

执行摘要

修复逐层重载中别名缓冲区损坏导致 NaN 的问题

生产环境中发现 nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16 模型在反复权重重载后出现 NaN。追踪发现 MambaMixer2 注册了非持久化缓冲区 conv_weights,该缓冲区与子参数 conv1d.weight 共享存储,逐层重载回拷时错误地覆盖了已加载的子参数权重。Issue 评论中也确认了相同问题。

值得精读。该 PR 展示了如何在 PyTorch 中安全地检测和跳过共享存储的缓冲区,设计模式(预计算指针集合、异常安全处理)可复用于其他需要操作 tensor 别名的场景。review 过程中对性能优化和逻辑简化的讨论也体现了良好的工程实践。

讨论亮点

针对 #42481 的 gemini-code-assist[bot] review 评论:

  • 高优先级_copy_and_restore_kernel_tensors 中初始的别名检测逻辑存在缺陷——物化后层的 tensor 是新分配,存储指针不匹配;且 parameters.values() 仅包含直接参数,缺少子参数。建议简化为直接检查 buffer 是否存在于 layer._buffers 中。本 PR 后续提交采用了该建议。
  • 高优先级_is_non_persistent_parameter_alias_buffer 对每个 buffer 递归遍历子树,导致 O(B×P) 复杂度。建议在 capture_layer_to_meta 中提前计算参数存储指针集合并复用。本 PR 后续提交(Precompute layerwise alias storage pointers)将该集合作为参数传入,解决了该问题,且最终移除了 _is_parameter_alias_tensor,仅保留预计算指针集合的版本。

实现拆解

  1. 定义存储指针比较工具meta.py):新增 _tensor_storage_ptr 获取 tensor 的 untyped_storage().data_ptr(),处理 UninitializedParameter 和异常;新增 _parameter_storage_ptrs 递归收集模块子树所有参数的存储指针集合。
  2. 检测别名缓冲区meta.py):新增 _is_non_persistent_parameter_alias_buffer,判断 buffer 是否为非持久化且其存储指针在参数指针集合中。
  3. 修改元数据捕获meta.pycapture_layer_to_meta):在捕获缓冲区时,除 SKIP_TENSORS 外额外跳过别名缓冲区,避免其被记录到 restore_metadata
  4. 回拷阶段保护layerwise.py_copy_and_restore_kernel_tensors):遍历原始内核缓冲区时,检查 name not in layer._buffers 则跳过,因为跳过的别名缓冲区不存在于临时物化层中,避免 getattr 失败或错误拷贝。
  5. 测试配套test_reload.py):添加 _AliasedBufferLayer(自身模块内别名)、_ParentAliasedChildBufferLayer(父缓冲区→子参数别名)、_AliasedBufferWithUninitializedChildLayer(含未初始化参数)三个测试层;新增三个测试函数覆盖跳过逻辑、未初始化指针处理和完整重载流程验证。
文件 模块 状态 重要度
vllm/model_executor/model_loader/reload/meta.py 重载模块 modified 7.99
vllm/model_executor/model_loader/reload/layerwise.py 重载模块 modified 5.1
tests/model_executor/model_loader/test_reload.py 测试 modified 7.46

关键符号

_is_non_persistent_parameter_alias_buffer _tensor_storage_ptr _parameter_storage_ptrs capture_layer_to_meta _copy_and_restore_kernel_tensors test_layerwise_reload_skips_non_persistent_parameter_alias_buffers test_layerwise_reload_skips_child_parameter_alias_buffers test_capture_layer_to_meta_skips_uninitialized_parameter_storage_ptrs

关键源码片段

vllm/model_executor/model_loader/reload/meta.py data-contract

核心修改文件:新增存储指针比较工具和别名检测函数,修改 capture_layer_to_meta 跳过别名缓冲区,是修复的关键。

# vllm/model_executor/model_loader/reload/meta.py ( 核心修改 )def _is_non_persistent_parameter_alias_buffer(
    layer: torch.nn.Module,
    name: str,
    buffer: torch.Tensor,
    parameter_storage_ptrs: set[int],
) -> bool:
    # 只检查非持久化缓冲区
    if name not in layer._non_persistent_buffers_set:
        return False
    buffer_storage_ptr = _tensor_storage_ptr(buffer)
    return (
        buffer_storage_ptr is not None
        and buffer_storage_ptr in parameter_storage_ptrs
    )def _tensor_storage_ptr(tensor: torch.Tensor) -> int | None:
    # 未初始化参数直接返回 None
    if isinstance(tensor, UninitializedParameter):
        return None
    try:
        return tensor.untyped_storage().data_ptr()
    except (RuntimeError, ValueError):
        return Nonedef _parameter_storage_ptrs(layer: torch.nn.Module) -> set[int]:
    # 递归收集所有参数的存储指针(包括子模块),一次遍历即可复用
    return {
        storage_ptr
        for param in layer.parameters(recurse=True)
        if (storage_ptr := _tensor_storage_ptr(param)) is not None
    }# 修改后的 capture_layer_to_meta
def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
    if layer.__class__.__name__ in SKIP_MODULES:
        return ({}, {})
    params, buffers = get_layer_params_buffers(layer)
    # 预计算参数指针集合 (O(P),之后每个 buffer 的检查为 O(1))
    parameter_storage_ptrs = _parameter_storage_ptrs(layer)
    return (
        {name: sanitize_layer_refs(to_meta_tensor(param), layer)
         for name, param in params.items()
         if name not in SKIP_TENSORS},
        {name: sanitize_layer_refs(to_meta_tensor(buffer), layer)
         for name, buffer in buffers.items()
         if name not in SKIP_TENSORS
         and not _is_non_persistent_parameter_alias_buffer(
             layer, name, buffer, parameter_storage_ptrs
         )},
    )
vllm/model_executor/model_loader/reload/layerwise.py core-logic

回拷阶段加入保护性判断,跳过不存在于物化层中的缓冲区,防止因缺失别名缓冲区导致的属性访问错误。

# vllm/model_executor/model_loader/reload/layerwise.py ( 回拷保护 )def _copy_and_restore_kernel_tensors(
    layer: torch.nn.Module, info: LayerReloadingInfo
):
    """将已处理的值拷贝回原始内核 tensor 存储,并恢复 cudagraph 引用。"""
    assert info.kernel_tensors is not None
    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        param.data.copy_(getattr(layer, name))
    for name, buffer in buffers.items():
        # 如果该缓冲区被元数据捕获阶段跳过(别名缓冲区),
        # 则物化后的 layer 中不存在该属性,直接跳过拷贝
        if name not in layer._buffers:
            continue
        buffer.data.copy_(getattr(layer, name))
    _place_kernel_tensors(layer, info)

评论区精华

回拷阶段别名检测逻辑缺陷 正确性

gemini-code-assist[bot] 指出 `_copy_and_restore_kernel_tensors` 中初始的别名检测 ( 通过 `_is_parameter_alias_tensor` 比较存储指针 ) 是错误的:物化后的 layer 包含新分配 tensor,指针必然不匹配;且 `parameters.values()` 只包含直接参数,遗漏子参数。建议简化为检查 buffer 是否在 `layer._buffers` 中。

结论:PR 作者采纳建议,移除了 `_is_parameter_alias_tensor`,改为直接对每个原始缓冲区判断 `name not in layer._buffers`,若不存在则跳过。 · 已解决

元数据捕获中 O(B×P) 性能隐患 性能

gemini-code-assist[bot] 指出 `_is_non_persistent_parameter_alias_buffer` 每次调用都会递归遍历模块子树计算参数指针,导致 O(B×P) 复杂度,在缓冲区较多时影响性能。建议预计算指针集合。

结论:PR 作者在后续提交中引入 `_parameter_storage_ptrs` 函数,在 `capture_layer_to_meta` 中预计算指针集合并传入检查函数,最终代码完全消除了重复递归。 · 已解决

风险与影响

主要风险来自存储指针比较可能对零大小 tensor、未初始化参数或自定义 storage 产生误判。函数内部通过 try-exceptUninitializedParameter 检查处理了已知异常情况。跳过别名缓冲区仅影响非持久化缓冲区,此类缓冲区本质上是参数的视图,重载参数后视图自动有效,因此跳过是安全的。回拷保护只跳过不在物化层中的缓冲区,不会遗漏必要的拷贝。总体风险可控。

用户/系统:修复了 MambaMixer2 等使用非持久化参数别名缓冲区的模型在逐层重载时产生 NaN 的严重错误,保障了模型训练/推理中权重重载的正确性。影响范围限于使用 layerwise reload 功能的场景。团队:增加了约 150 行测试代码,提高了对类似别名模式的防御能力。代码复杂度略有上升,但通过清晰的辅助函数和 precompute 模式保持了可维护性。

核心路径变更 存储指针跨平台风险 UninitializedParameter 边界处理

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论