python - Custom transform on a subset of CIFAR10 - PyTorch
Problem description
I am trying to create a custom transform for part of the CIFAR10 dataset that superimposes an image onto the dataset's images. I was able to download the data and split it into subsets using the following code:
import torch
from torchvision import datasets, transforms

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

traindata = datasets.CIFAR10('./data', train=True, download=True,
                             transform=transform_train)

partitions = 5
traindata_split = torch.utils.data.random_split(
    traindata,
    [int(traindata.data.shape[0] / partitions) for _ in range(partitions)])
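The CIFAR10 training split contains 50000 images, so each of the 5 partitions receives 10000 samples; torch.utils.data.random_split expects the requested lengths to sum to the dataset length, which can be sanity-checked like so:

print(len(traindata))                     # 50000
print([len(s) for s in traindata_split])  # [10000, 10000, 10000, 10000, 10000]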
I then want to modify some of the splits, so I created the following class and function:
from torch.utils.data import Dataset

class MyDataset(Dataset):  # https://discuss.pytorch.org/t/torch-utils-data-dataset-random-split/32209/3
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)
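One thing worth probing before stacking a second pipeline on top: traindata was built with transform_train, so each split already yields normalized (3, 32, 32) tensors rather than PIL images. Wrapping a split without an extra transform makes this easy to verify:

part0 = MyDataset(traindata_split[0])
x, y = part0[0]
print(type(x), x.shape)  # <class 'torch.Tensor'> torch.Size([3, 32, 32])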
and
import cv2
import torch

class ImageSuperImpose(object):
    """Image input as PIL and output as PIL.
    To be used as part of torchvision.transforms.
    Args: p, a threshold value to control image thinning.
    """
    def __init__(self, p=0):
        self.p = p

    def __call__(self, image):
        img = cv2.imread('img.jpg')
        img = img.astype('float32') / 255
        imgSm = cv2.resize(img, (32, 32))
        np_arr = image.cpu().detach().numpy().T
        sample = cv2.addWeighted(np_arr, 1, imgSm, 1, 0)
        sample = sample.T
        t = torch.from_numpy(sample)
        return sample
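For reference, cv2.addWeighted(src1, alpha, src2, beta, gamma) computes src1*alpha + src2*beta + gamma elementwise and requires both inputs to have matching shapes, so with the weights used above the overlay is simply added to the image (float inputs are not clipped):

import numpy as np
a = np.ones((2, 2, 3), dtype='float32')
b = np.full((2, 2, 3), 0.5, dtype='float32')
print(cv2.addWeighted(a, 1, b, 1, 0)[0, 0])  # [1.5 1.5 1.5]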
transform_train2 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    ImageSuperImpose(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
datasetA = MyDataset(
    traindata_split[0], transform=transform_train2
)

test_loader = torch.utils.data.DataLoader(datasetA, batch_size=128, shuffle=True)
But when I try to train a model on this subset, I get the following error:
RuntimeError: The size of tensor a (32) must match the size of tensor b (3) at non-singleton dimension 0
**Update** Here is the full error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-20-7428084b03be> in <module>()
----> 1 train(model, opt, test_loader, 3)
9 frames
<ipython-input-14-fcb03e1d7685> in client_update(client_model, optimizer, train_loader, epoch)
5 client_model.train()
6 for e in range(epoch):
----> 7 for batch_idx, (data, target) in enumerate(train_loader):
8 data, target = data.to(device), target.to(device)
9 optimizer.zero_grad()
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
433 if self._sampler_iter is None:
434 self._reset()
--> 435 data = self._next_data()
436 self._num_yielded += 1
437 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
473 def _next_data(self):
474 index = self._next_index() # may raise StopIteration
--> 475 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
476 if self._pin_memory:
477 data = _utils.pin_memory.pin_memory(data)
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
<ipython-input-7-1bde43acaff0> in __getitem__(self, index)
7 x, y = self.subset[index]
8 if self.transform:
----> 9 x = self.transform(x)
10 return x, y
11
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
65 def __call__(self, img):
66 for t in self.transforms:
---> 67 img = t(img)
68 return img
69
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in forward(self, tensor)
224 Tensor: Normalized Tensor image.
225 """
--> 226 return F.normalize(tensor, self.mean, self.std, self.inplace)
227
228 def __repr__(self):
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in normalize(tensor, mean, std, inplace)
282 if std.ndim == 1:
283 std = std.view(-1, 1, 1)
--> 284 tensor.sub_(mean).div_(std)
285 return tensor
286
RuntimeError: The size of tensor a (32) must match the size of tensor b (3) at non-singleton dimension 0
Solution
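The traceback pins the failure inside Normalize: it receives a tensor whose leading dimension is 32 rather than 3. Two things combine to cause this. First, traindata was created with transform_train, so the subset already yields channels-first tensors, not PIL images. Second, ImageSuperImpose.__call__ returns a channels-first numpy array (note it also ignores the tensor t it builds), which the second ToTensor interprets as an (H, W, C) image and permutes into shape (32, 3, 32); Normalize's 3-element mean then fails to broadcast against dimension 0. Below is a minimal sketch of one way to restructure this, assuming the goal is simply to add img.jpg onto each sample: build the base dataset without a transform so the pipeline runs exactly once per subset, and let the overlay transform take and return a (C, H, W) tensor after ToTensor:

import cv2
import torch
from torchvision import datasets, transforms

class ImageSuperImpose(object):
    """Blend a fixed overlay into a (C, H, W) float tensor.
    Place this transform after ToTensor() so the input is already a tensor."""
    def __init__(self, path='img.jpg'):
        img = cv2.imread(path).astype('float32') / 255  # (H, W, C) in [0, 1]
        img = cv2.resize(img, (32, 32))
        self.overlay = torch.from_numpy(img).permute(2, 0, 1)  # (C, H, W)

    def __call__(self, tensor):
        # Same arithmetic as cv2.addWeighted(arr, 1, overlay, 1, 0)
        return tensor + self.overlay

transform_train2 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    ImageSuperImpose('img.jpg'),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Build the base dataset WITHOUT a transform so each subset hands
# raw PIL images to MyDataset and the pipeline runs exactly once.
traindata = datasets.CIFAR10('./data', train=True, download=True)

With this ordering, RandomCrop and RandomHorizontalFlip receive PIL images as intended, ToTensor performs the only layout change, and Normalize sees the expected (3, 32, 32) tensor.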