首页 > 解决方案 > 如何向 CIFAR10 torchvision 添加新样品?

问题描述

嗨,我想将自己的图像添加到 torchvision 中的 CIFAR10 数据集,我该怎么做?

train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
train_data.add # or a workaround!

谢谢

标签: datasetpytorchtorchvisioncustom-dataset

解决方案


您可以使用此处CIFAR10的原始 cifar10 图像创建自定义数据集,或者您仍然可以在新的自定义数据集中使用数据集,然后在方法中添加您的逻辑。 这是一个简单的例子来帮助你:CIFAR10__getitem__()

class CIFAR10_2(torch.utils.data.Dataset):
    def __init__(self, dataset_path='/cifar10', transformations=None, should_download=True):
        self.dataset_train = torchvision.datasets.CIFAR10(dataset_path, download=should_download)
        self.transformations = transformations

    def __getitem__(self, index):
        # do as you wish , add your logic here
        (img, label) = self.dataset_train[index]
        # for transformations for example
        if self.transformations is not None:
            return self.transformations(img), label
        return img, label

    def __len__(self):
        return len(self.dataset_train)

您可以花哨并为测试、验证等添加逻辑,并做任何您喜欢的事情。


推荐阅读