首页 > 解决方案 > 导入图像(.jpg)数据集 Keras、Pandas 的正确方法

问题描述

在过去的几天里,我正在从事一个机器学习项目。

我有一个图像数据集(.jpg)。我有超过 500K 的图像。

除此之外,我有一个 CSV 文件,其中包含图像的名称(每个图像都有一个唯一的名称)和两个标签(目标值)。这两个目标标签完全不同,彼此之间没有任何关系。

我将为两个目标标签使用模型分离模型。

我的解决方案

  1. 将所有内容转换为一个大的 CSV 文件。类似于 CSV 格式的 MNIST 数据集。这种方法的问题是图像尺寸很大(我需要大图像)和三个通道(彩色图像)。所以 CSV 文件的大小变得超级大。

  2. 使用Keras ImageDataGenerator 和 flow_from_directory类。正如我之前提到的,我有两个标签(目标),所以需要创建同一个数据集的两个副本(因为 flow_from_directory 需要特定的数据结构)

现在,我的两个解决方案都有效,但存在特定问题。

我想知道有没有其他方法可以导入数据集。这样我就可以避免上面提到的问题。

我在这个项目中使用 Keras、Pandas、Numpy 和 Sklearn。我也可以自由使用任何其他库。

我没有在这个问题上附上我的解决方案的任何代码。如果需要,请告诉我。

Thnx阿布舍克

标签: pythonpandaskeras

解决方案


你提到了熊猫,但我认为这不能解决你的问题。

你为什么不写你自己的解决方案?

您可以尝试实现scikit-learn的方式。

识别手写数字为例,

示例代码

# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# License: BSD 3 clause

import matplotlib.pyplot as plt    
# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics

# The digits dataset 
digits = datasets.load_digits() # <--- right here

images_and_labels = list(zip(digits.images, digits.target))
for index, (image, label) in enumerate(images_and_labels[:4]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Training: %i' % label)

n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

classifier = svm.SVC(gamma=0.001)

classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])

expected = digits.target[n_samples // 2:]
predicted = classifier.predict(data[n_samples // 2:])

print("Classification report for classifier %s:\n%s\n"
      % (classifier, metrics.classification_report(expected, predicted)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    plt.subplot(2, 4, index + 5)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prediction: %i' % prediction)

plt.show()

源代码

scikit-learn构建一个dataset仅用于加载不同数据集的模块,例如 MNIST(图像和标签)。

您还将享受阅读dataset.load_digits()源代码的乐趣

简短整洁。希望你能找到更好的解决方案。


推荐阅读