pytorch - 如何将图像变换应用于图像列表并保持正确的尺寸?
问题描述
我正在使用 Omniglot 数据集,它是一组 19,280 张图像,每张图像都是 105 x 105(灰度)。
我使用以下转换定义了一个自定义数据集类:
class OmniglotDataset(Dataset):
def __init__(self, X, transform=None):
self.X = X
self.transform = transform
def __len__(self):
return self.X.shape[0]
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img = self.X[idx]
if self.transform:
img = self.transform(img)
return img
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
X_train.shape
(19280, 105, 105)
train_dataset = OmniglotDataset(X_train, transform=img_transform)
当我索引单个图像时,它会返回正确的尺寸:
train_dataset[0].shape
torch.Size([1, 105, 105])
但是当我索引几个图像时,它以错误的顺序返回尺寸(我期望3 x 105 x 105
):
train_dataset[[1,2,3]].shape
torch.Size([105, 3, 105])
解决方案
您收到错误是因为尝试将单个图像的转换应用于列表:
获取任意大小的批次的更方便的方法是使用 Dataloader:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
omniglot = datasets.Omniglot(root='./data', background=True, download=True, transform = img_transform)
data_loader = DataLoader(omniglot, shuffle=False, batch_size = 8)
for image_batch in data_loader:
# now image_batch contain first eight samples
print(image_batch.shape) # torch.Size([8, 1, 105, 105])
break
如果您确实需要以任意顺序获取图像:
from operator import itemgetter
indexes = [1,3,5]
selected_samples = itemgetter(*b)(omniglot)