首页 > 解决方案 > Pytorch 中的内存泄漏:对象检测

问题描述

我正在研究PyTorch上的对象检测教程。原始教程适用于给定的几个时期。我将它扩展到大时代并遇到内存不足错误。

我试图调试它并发现一些有趣的东西。这是我正在使用的工具:

def debug_gpu():
    # Debug out of memory bugs.
    tensor_list = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                tensor_list.append(obj)
        except:
            pass
    print(f'Count of tensors = {len(tensor_list)}.')

我用它来监控训练一个 epoch 的记忆:

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    ...
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        # inference + backward + optimization
        debug_gpu()

输出是这样的:

Count of tensors = 414.
Count of tensors = 419.
Count of tensors = 424.
Count of tensors = 429.
Count of tensors = 434.
Count of tensors = 439.
Count of tensors = 439.
Count of tensors = 444.
Count of tensors = 449.
Count of tensors = 449.
Count of tensors = 454.

如您所见,垃圾收集器跟踪的张量数量不断增加。

要执行的相关文件可以在这里找到。

我有两个问题: 1. 是什么阻碍了垃圾收集器释放这些张量?2. 内存不足错误怎么办?

标签: python-3.xdebuggingmemory-leakspytorch

解决方案


  1. 我如何识别错误?在tracemalloc的帮助下,我拍摄了两个快照,其中有数百次迭代。本教程将向您展示它很容易遵循。

  2. 是什么导致错误? rpn.anchor_generator._cache在 Pytorch 中是一个dict跟踪网格锚点的 python。它是检测模型的一个属性,大小随着每个提议而增加。

  3. 如何解决?model.rpn.anchor_generator._cache.clear()在训练迭代结束时放置一个简单的旁路。


我已经向PyTorch提交了一个修复程序。自torchvision 0.5 以来,您可能不会出现OOM 错误。


推荐阅读