首页 > 解决方案 > 我在这里缺少什么,使用 ImageFolder 获取完整文件夹名称作为 MNIST 双数据集图像的标签?

问题描述

我想使用dataset.ImageFolder创建一个图像数据集。

我当前的图像目录结构如下所示:

1在火车图像中,我的标签包含 00、01 等子文件夹。在每个文件夹中,图像包含与每个标签对应的两位数字

这是我使用的代码,后面是标签没有的输出。与图像匹配

这里的路径

data_dir = "/home/mhamdan/hamdan/MNIST_muldigits/data/double_mnist"
train_dir = data_dir + '/train' # training_set contains training dataset
val_dir = data_dir + '/val'  #contains validation dataset
test_dir = data_dir + '/test'  #contains test dataset

在此处加载数据

#Load the dataset with Image Folder
trainset = datasets.ImageFolder(train_dir, transform = transformation)
valset = datasets.ImageFolder(val_dir, transform = transformation)
testset = datasets.ImageFolder(test_dir, transform = transformation)

数据加载器

#define data loaders
batch_size = 32
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,num_workers=2)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=True,num_workers=2)
test_loader = DataLoader(testset, batch_size=batch_size,num_workers=1)

这是随机训练图像的绘图

examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
import matplotlib.pyplot as plt

fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[1]))
  plt.xticks([])
  plt.yticks([])
fig

如您在此处看到的,标签与图像不同 标签与图像不同

每个子文件夹都包含一个与标签关联的唯一标签 子文件夹 iamge

这里是 01 子目录中的图像01

使用索引后的最后更新。还是同样的问题,更新索引后

标签: python-3.xlabelpytorch-dataloader

解决方案


这是通过将索引作为字典labelsdec = trainset.class_to_idx并使用此函数将键提取为标签/类来解决我的问题

def getList(dict):
list = []
for key in dict.keys():
    list.append(key)
      
return list
def getList(dict):
list = []
for key in dict.keys():
    list.append(key)
      
return list
  
classes = getList(labelsdec)
  

细化 10 个图像:

def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    plt.imshow(np.transpose(img, (1, 2, 0)))  # convert from Tensor image
# obtain one batch of training images
data_iter = iter(train_loader)
images, lbls = data_iter.next()
images = images.numpy() # convert images to numpy for display
# plot the images in the batch, along with the corresponding labels
fig = plt.figure(figsize=(10, 4))
# display 20 images
for idx in np.arange(10):
    ax = fig.add_subplot(2, 10/2, idx+1, xticks=[], yticks=[])
    imshow(images[idx])
    label = lbls[idx]
    ax.set_title(classes[lbls[idx]])

这是它的外观看图片子目录的标签与图像内容匹配的所需输出


推荐阅读