python-3.x - 我在这里缺少什么,使用 ImageFolder 获取完整文件夹名称作为 MNIST 双数据集图像的标签?
问题描述
我想使用dataset.ImageFolder创建一个图像数据集。
我当前的图像目录结构如下所示:
1:在火车图像中,我的标签包含 00、01 等子文件夹。在每个文件夹中,图像包含与每个标签对应的两位数字
这是我使用的代码,后面是标签没有的输出。与图像匹配
这里的路径
data_dir = "/home/mhamdan/hamdan/MNIST_muldigits/data/double_mnist"
train_dir = data_dir + '/train' # training_set contains training dataset
val_dir = data_dir + '/val' #contains validation dataset
test_dir = data_dir + '/test' #contains test dataset
在此处加载数据
#Load the dataset with Image Folder
trainset = datasets.ImageFolder(train_dir, transform = transformation)
valset = datasets.ImageFolder(val_dir, transform = transformation)
testset = datasets.ImageFolder(test_dir, transform = transformation)
数据加载器
#define data loaders
batch_size = 32
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,num_workers=2)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=True,num_workers=2)
test_loader = DataLoader(testset, batch_size=batch_size,num_workers=1)
这是随机训练图像的绘图
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[1]))
plt.xticks([])
plt.yticks([])
fig
如您在此处看到的,标签与图像不同 标签与图像不同
解决方案
这是通过将索引作为字典labelsdec = trainset.class_to_idx
并使用此函数将键提取为标签/类来解决我的问题
def getList(dict):
list = []
for key in dict.keys():
list.append(key)
return list
def getList(dict):
list = []
for key in dict.keys():
list.append(key)
return list
classes = getList(labelsdec)
细化 10 个图像:
def imshow(img):
img = img / 2 + 0.5 # unnormalize
plt.imshow(np.transpose(img, (1, 2, 0))) # convert from Tensor image
# obtain one batch of training images
data_iter = iter(train_loader)
images, lbls = data_iter.next()
images = images.numpy() # convert images to numpy for display
# plot the images in the batch, along with the corresponding labels
fig = plt.figure(figsize=(10, 4))
# display 20 images
for idx in np.arange(10):
ax = fig.add_subplot(2, 10/2, idx+1, xticks=[], yticks=[])
imshow(images[idx])
label = lbls[idx]
ax.set_title(classes[lbls[idx]])
推荐阅读
- reactjs - Using keyof in Typescript to get rid of 'element implicitly has any type' error
- scroll - Why is there no vertical scroll bar on my markdown file?
- java - Using BigQuery Storage API for small but frequent queries
- google-apps-script - Error getting the last row of a sheet in Google Sheet archiving macro
- python - Python 3 | How to print the dataframes by their names once they are created dynamically?
- windows - SQL Server Management Studio 安装程序 exe 无法启动
- r - 在 R 中使用 terra 包从不同程度的栅格镶嵌错误?
- git - 如何获取特定版本的 Chromium 源代码?
- swift - 如何在swift中进行函数回调
- node.js - 如何解决 Google App Engine 上 VueJs / Express / Mongo 应用程序的 API 响应的“启用 Javascript”错误