首页 > 解决方案 > 将自定义标签添加到 pytorch 数据加载器/数据集不适用于自定义数据集

问题描述

我正在参加 Kaggle 上的仙人掌图像比赛,我正在尝试将 PyTorch 数据加载器用于我的 CNN。但是,我遇到了无法为训练集设置标签的问题。训练集图像在文件夹中,标签在 csv 文件中。这是我的代码。

 train = torchvision.datasets.ImageFolder(root='../input/train', 
 transform=transform)

 train.targets = torch.from_numpy(df['has_cactus'].values)

 train_loader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True, num_workers=2)

 for i, data in enumerate(train_loader, 0):
     print(data[1])

此代码输出全为零的批量张量,这显然是不正确的,因为绝大多数标签(如果您要查看数据框)都是标签。我相信这是将标签分配给“train.targets”的问题。如果在分配其他标签之前打印“train.targets”,它会返回一个全为零的张量,这与我得到的不正确结果一致。我该如何解决这个问题?

标签: pythonmachine-learningcomputer-visionpytorch

解决方案


我通常继承内置的 DataSet 类,如下所示:

from torch.utils.data import DataLoader
class DataSet:

    def __init__(self, root):
        """Init function should not do any heavy lifting, but
            must initialize how many items are available in this data set.
        """

        self.ROOT = root
        self.images = read_images(root + "/images")
        self.labels = read_labels(root + "/labels")

    def __len__(self):
        """return number of points in our dataset"""

        return len(self.images)

    def __getitem__(self, idx):
        """ Here we have to return the item requested by `idx`
            The PyTorch DataLoader class will use this method to make an iterable for
            our training or validation loop.
        """

        img = images[idx]
        label = labels[idx]

        return img, label

现在,您可以创建此类的一个实例,

ds = Dataset('../input/train')

现在,您可以实例化 DataLoader:

dl = DataLoader(ds, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True)

这将创建一批您可以访问的数据:

for image, label in dl:
    print(label)

推荐阅读