首页 > 解决方案 > Pytorch:保存张量列表的最佳实践?

问题描述

我使用张量进行转换,然后将其保存在列表中。稍后,我将使用 将其作为数据集Dataset,然后最终DataLoader训练我的模型。为此,我可以简单地使用:

l = [tensor1, tensor2, tensor3,...]
dataset = Dataset.TensorDataset(l)
dataloader = DataLoader(dataset)

我想知道这样做的最佳做法是什么,以避免在大小l增长时RAM溢出?Iterator可以避免它吗?

标签: pytorch

解决方案


保存张量

for idx, tensor in enumerate(dataloader0):
    torch.save(tensor, f"{my_folder}/tensor{idx}.pt")

创建数据集

class FolderDataset(Dataset):
   def __init__(self, folder):
       self.files = os.listdir(folder)
       self.folder = folder
   def __len__(self):
       return len(self.files)
   def __getitem__(self, idx):
       return torch.load(f"{self.folder}/{self.files[idx]}")

然后你可以实现你自己的数据加载器。如果您不能将整个数据集保存在内存中,则需要加载一些文件系统。


推荐阅读