首页 > 解决方案 > Python pytorch 函数过快消耗内存

问题描述

我正在使用 pytorch 编写一个函数,该函数通过转换器模型提供输入,然后通过计算沿特定轴的平均值(使用掩码定义的索引子集)来压缩最后一个嵌入层。由于模型的输出非常非常大,我需要对输入进行批量处理。

我的问题与此功能的逻辑无关,因为我相信我有正确的实现。我的问题是我编写的函数过快地消耗内存并且实际上使其无法使用。

这是我的功能:

def get_chunk_embeddings(encoded_dataset, batch_size):
  chunk_embeddings = torch.empty([0,768])
  for i in range(len(encoded_dataset['input_ids'])//batch_size):
    input_ids = encoded_dataset['input_ids'][i*batch_size:i*batch_size + batch_size]
    attention_mask = encoded_dataset['attention_mask'][i*batch_size:i*batch_size + batch_size]
    embeddings = model.forward(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']
    embeddings = embeddings * attention_mask[:,:,None]
    embeddings = embeddings.sum(dim=1)/attention_mask.sum(dim=1)[:,None]
    chunk_embeddings = torch.cat([chunk_embeddings, embeddings],0)
  return chunk_embeddings

现在让我们谈谈内存(下面的数字假设我传递了 8 的 batch_size):

所以根据我的理解,我应该能够允许chunk_embeddings增长到:25GB - 413MB - 0.48GB - 0.413MB - 4.096KB - 12.6MB ~= 24 GB。足以进行近 100 万次迭代。

在这里,我将通过一个例子来说明我正在经历的事情:

  1. 在运行我的函数之前,google colab 告诉我我有足够的内存

在此处输入图像描述

  1. 现在,为了举例,我将运行该函数(仅 3 次迭代)为了明确起见,我将其放在 for 循环的末尾: if (i == 2):return chunk_embeddings

  2. 现在我运行代码val = get_chunk_embeddings(train_encoded_dataset, 8) 所以即使只有 3 次迭代,不知何故我消耗了将近 5.5 GB 的 RAM。

在此处输入图像描述

为什么会这样?此外,在我从函数返回后,所有局部变量都应该被删除,而且没有val这么大的方法。

有人可以告诉我我做错了什么或不理解吗?如果需要更多信息,请告诉我。

标签: pythonmemory-managementmemory-leakspytorchram

解决方案


为了扩展@GoodDeeds 的答案,默认情况下,pytorch.nn模块(模型)中的计算会创建计算图并保留梯度(除非您正在使用with torch.no_grad()或类似的东西。这意味着在循环的每次迭代中,嵌入的计算图存储在张量embeddings中 .embeddings.grad可能比embeddings自身大得多,因为每个层值相对于每个前一层值的梯度保持不变。接下来,由于您使用torch.cat,因此将关联的梯度附加embeddingsdchunk_embeddings。这意味着经过几次迭代,chunk_embeddings存储了大量的梯度值,这就是你记忆的去向。有几个解决方案:

  1. 如果您需要使用块嵌入进行反向传播(即训练),您应该在循环内移动损失计算和优化器步骤,以便之后自动清除梯度。

  2. 如果此函数仅在推理期间使用,您可以使用 完全禁用梯度计算(这也应该稍微加快计算速度)torch.no_grad(),或者您可以按照注释中的建议在每次迭代时使用torch.detach()on 。embeddings

例子:

def get_chunk_embeddings(encoded_dataset, batch_size):
  with torch.no_grad():
    chunk_embeddings = torch.empty([0,768])
    for i in range(len(encoded_dataset['input_ids'])//batch_size):
      input_ids = encoded_dataset['input_ids'][i*batch_size:i*batch_size + batch_size]
      attention_mask = encoded_dataset['attention_mask'][i*batch_size:i*batch_size + batch_size]
      embeddings = model.forward(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']
      embeddings = embeddings * attention_mask[:,:,None]
      embeddings = embeddings.sum(dim=1)/attention_mask.sum(dim=1)[:,None]
      chunk_embeddings = torch.cat([chunk_embeddings, embeddings],0)
  return chunk_embeddings
    

推荐阅读