PyTorch DataLoader is very slow

Problem description

I have a problem with PyTorch's DataLoader: it is very slow.

I ran a test to demonstrate this. Here is the code:

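import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# `model` and `criterion` are defined elsewhere in my code (not shown here)

data = np.load('slices.npy')
data = np.reshape(data, (-1, 1225))
data = torch.FloatTensor(data).to('cuda')
print(data.shape)
# ==> torch.Size([273468, 1225])

# Map-style dataset that returns one row of the tensor per index
class UnlabeledTensorDataset(Dataset):
    def __init__(self, data_tensor):
        self.data_tensor = data_tensor
        self.samples = data_tensor.shape[0]

    def __getitem__(self, index):
        return self.data_tensor[index]

    def __len__(self):
        return self.samples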

test_set = UnlabeledTensorDataset(data)
test_loader = DataLoader(test_set, batch_size=data.shape[0])

start = datetime.datetime.now()
with torch.no_grad():
    for batch in test_loader:
        print(batch.shape)     # ==> torch.Size([273468, 1225])
        y_pred = model(batch)
        loss = torch.sqrt(criterion(y_pred, batch))
        avg_loss = loss
print(round((datetime.datetime.now() - start).total_seconds() * 1000, 2))
# ==> 1527.57  (milliseconds)   !!!!!!!!!!!!!!!!!!!!!!!!

start = datetime.datetime.now()
with torch.no_grad():
    print(data.shape)     # ==> torch.Size([273468, 1225])
    y_pred = model(data)
    loss = torch.sqrt(criterion(y_pred, data))
    avg_loss = loss
print(round((datetime.datetime.now() - start).total_seconds() * 1000, 2))
# ==> 2.0     (milliseconds)    !!!!!!!!!!!!!!!!!!!!!!!!
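A caveat on the 2.0 ms figure: CUDA kernels are launched asynchronously, so reading the wall clock right after `model(data)` mostly measures launch overhead, not the GPU work itself. A minimal sketch of a fairer timer is below. It assumes a hypothetical helper `timed_ms` and uses a stand-in `torch.nn.Linear` model with random data (not the model or data from the question), and it calls `torch.cuda.synchronize()` around the measured call only when a GPU is available.

```python
import time

import torch


def timed_ms(fn, *args):
    """Run fn(*args) and return (result, elapsed milliseconds).

    torch.cuda.synchronize() blocks until all queued GPU kernels finish,
    so the measured interval covers the computation, not just the launch.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()  # drain earlier work before starting the clock
    start = time.perf_counter()
    result = fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for this call's kernels to complete
    elapsed_ms = (time.perf_counter() - start) * 1000
    return result, elapsed_ms


# Hypothetical stand-in model and data; the question's model is not shown
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1225, 1225).to(device)
data = torch.randn(1024, 1225, device=device)

with torch.no_grad():
    y_pred, ms = timed_ms(model, data)
print(y_pred.shape, round(ms, 2))
```

Timing both variants this way makes the comparison reflect the actual GPU work in each case.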

I want to use DataLoader, but I need a way to fix this slowness. Does anyone know why this happens?

Tags: python, pytorch

Solution


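The time difference seems logical to me:

  • On one side, you are looping over test_loader. Because batch_size equals the dataset size, the loop yields only one batch, but the DataLoader builds that batch by calling __getitem__ once per sample (273,468 Python-level calls, each returning a single row) and then stacking the rows with its default collate function.

  • On the other side, you are running a single inference on a tensor that already holds all the data.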
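The per-sample behaviour described above can be checked by counting __getitem__ calls. The sketch below assumes a small CPU tensor in place of the real data and a hypothetical `CountingDataset` wrapper. With default settings the loader calls __getitem__ once per row. Passing `batch_size=None` together with a `BatchSampler` disables automatic batching, so each list of indices goes to __getitem__ in one call and the whole batch is gathered with a single tensor indexing operation. This is one common workaround, not the only one.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class CountingDataset(Dataset):
    """Returns rows of a tensor and counts how often __getitem__ is called."""

    def __init__(self, data_tensor):
        self.data_tensor = data_tensor
        self.calls = 0

    def __getitem__(self, index):
        # `index` is an int under automatic batching, or a list of ints when
        # a BatchSampler is passed as the sampler with batch_size=None.
        self.calls += 1
        return self.data_tensor[index]

    def __len__(self):
        return self.data_tensor.shape[0]


data = torch.randn(1000, 1225)

# 1) Default automatic batching: one __getitem__ call per sample, then stack.
per_item = CountingDataset(data)
loader = DataLoader(per_item, batch_size=len(per_item))
(batch_a,) = list(loader)
print(per_item.calls)  # 1000

# 2) Batched indexing: batch_size=None turns off automatic batching and the
#    BatchSampler yields index lists that are passed straight to __getitem__.
batched = CountingDataset(data)
sampler = BatchSampler(SequentialSampler(batched), batch_size=len(batched), drop_last=False)
loader = DataLoader(batched, batch_size=None, sampler=sampler)
(batch_b,) = list(loader)
print(batched.calls)  # 1

print(torch.equal(batch_a, batch_b))  # True
```

Because `UnlabeledTensorDataset.__getitem__` already returns `self.data_tensor[index]`, it accepts a list of indices unchanged, so the same `batch_size=None` plus `BatchSampler` pattern should apply to `test_set` with whatever batch size you choose.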

