python - RuntimeError:/pytorch/torch/lib/TH/generic/THTensorMath.c:2864 处的张量大小不一致
问题描述
我正在尝试构建一个数据加载器,这就是它的样子
`class WhaleData(Dataset):
def __init__(self, data_file, root_dir , transform = None):
self.csv_file = pd.read_csv(data_file)
self.root_dir = root_dir
self.transform = transforms.Resize(224)
def __len__(self):
return len(os.listdir(self.root_dir))
def __getitem__(self, index):
image = os.path.join(self.root_dir, self.csv_file['Image'][index])
image = Image.open(image)
image = self.transform(image)
image = np.array(image)
label = self.csv_file['Image'][index]
sample = {'image': image, 'label':label}
return sample
trainset = WhaleData(data_file = '/mnt/55-91e8-b2383e89165f/Ryan/1234/train.csv',
root_dir = '/mnt/4d55-91e8-b2383e89165f/Ryan/1234/train')
train_loader = torch.utils.data.DataLoader(trainset , batch_size = 4, shuffle =True,num_workers= 2)
for i, batch in enumerate(train_loader):
(i, batch)
当我尝试运行这段代码时,我得到了这个错误,我确实得到了错误的性质,即我的所有图像可能不是相同的形状,而且我的图像也不是相同的形状,但如果我没有错只有当我将它们提供给网络时才会出现错误,因为图像都是不同的形状,但为什么会在这里抛出错误?任何关于我可能出错的地方的建议都会非常有帮助,如果需要,我很乐意提供任何额外的信息,
谢谢
RuntimeError: Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 116, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 116, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 105, in default_collate
return torch.stack([torch.from_numpy(b) for b in batch], 0)
File "/usr/local/lib/python3.5/dist-packages/torch/functional.py", line 64, in stack
return torch.cat(inputs, dim)
RuntimeError: inconsistent tensor sizes at /pytorch/torch/lib/TH/generic /THTensorMath.c:2864
解决方案
当 PyTorch 尝试将图像堆叠到一个批处理张量中时会出现错误(参见torch.stack([torch.from_numpy(b) for b in batch], 0)
您的跟踪)。正如您所提到的,由于图像具有不同的形状,因此堆叠失败(即,如果所有这些张量都具有形状,(B, H, W)
则只能通过堆叠张量来创建张量)。B
(H, W)
注意:我不完全确定,但设置batch_size=1
fortorch.utils.data.DataLoader(...)
可能会消除此特定错误,因为它可能不再需要调用torch.stack()
)。
推荐阅读
- unit-testing - jest.fn() v/s jest.mock()?
- nginx - Nginx 与每个 URI 的单独文件夹完全匹配的 uri
- c# - 使用 FTD2XX.DLL 中的 FTCSPI.dll 函数,使用 FT2232H 设备
- python - 根据来自另一个列表的索引从列表列表中获取元素
- cuda - 使用 CUDA 的矩阵的多个 SVD
- html - 卷轴从何而来?
- flutter - 如何允许 webview 中的 mailto 方案颤动
- python - 由于缺少“请求”模块,无法在 Linux 中运行 python 脚本
- clickhouse - 如何提高clickhouse的查询速度
- jenkins - 如何在 Jenkins Pipeline 中获取 SonarQube taskId/report URL?