首页 > 解决方案 > 使用 tensoflow 在 CNN 中加载大型数据集的最佳方法

问题描述

晚上好,我正在使用 tensorflow 1.4 在 CNN 中进行 Cat-Dog 分类,我想知道在没有内存不足问题的情况下将数据加载到 CNN 中的最佳方法是什么,知道数据集的大小(8000图像用于训练,2000 图像用于测试)。
我尝试通过在 opencv 中读取数据并将其转换为 numpy 数组并使用该 numpy 数组进行训练和测试来加载数据,但这种方法似乎存在内存问题(代码如下所示)。我也尝试过减少批量大小,但它似乎不起作用,有没有办法从磁盘加载数据?所以请您指导我找到一个带有实现或代码的清晰解决方案。这是我用于加载数据的代码:
这是一个标记数据的函数

# function that labels the data
def label_img(img):
    word_label = img.split('.')[-3]
    # DIY One hot encoder
    if word_label == 'cat':
        return [1, 0]
    elif word_label == 'dog':
        return [0, 1]


这是一个创建训练数据的函数,相同的函数用于测试数据。

def create_train_data():
    # Creating an empty list where we should store the training data
    # after a little preprocessing of the data
    training_data = []

    # tqdm is only used for interactive loading
    # loading the training data
    for img in tqdm(os.listdir(TRAIN_DIR)):
        # labeling the images
        label = label_img(img)

        path = os.path.join(TRAIN_DIR, img)

        # loading the image from the path and then converting them into
        # greyscale for easier covnet prob
        img = cv2.imread(path)
        img = cv2.resize(img, (256, 256))

        training_data.append([np.array(img), np.array(label)])

        # shuffling of the training data to preserve the random state of our data
    random.seed(101)
    random.shuffle(training_data)

    return training_data


这是包含批处理图像的会话代码,
这里 TrainX 和 TrainY 只是包含图像和标签的 nparray

with tf.Session() as sess:
    sess.run(init)
    train_loss = []
    test_loss = []
    train_accuracy = []
    test_accuracy = []
    for i in range(training_iters):
        for batch in range(len(trainX)//batch_size):
            batch_x = trainX[batch*batch_size:min((batch+1)*batch_size, len(trainX))]
            batch_y = trainY[batch*batch_size:min((batch+1)*batch_size, len(trainY))]
            # Run optimization op (backprop).
            # Calculate batch loss and accuracy
            opt = sess.run(optimizer, feed_dict={x: batch_x,
                                                              y: batch_y})
            loss, acc = sess.run([cost, accuracy], feed_dict={x: batch_x,
                                                              y: batch_y})
        print("Iter " + str(i) + ", Loss= " + \
                      "{:.6f}".format(loss) + ", Training Accuracy= " + \
                      "{:.5f}".format(acc))
        print("Optimization Finished!")

        # Calculate accuracy for all 10000 mnist test images
        test_acc,valid_loss = sess.run([accuracy, cost], feed_dict={x: testX, y : testY})
        train_loss.append(loss)
        test_loss.append(valid_loss)
        train_accuracy.append(acc)
        test_accuracy.append(test_acc)
        print("Testing Accuracy:", "{:.5f}".format(test_acc))

标签: pythontensorflow

解决方案


将数千张图像加载到内存中并不是正确的方法。

因此,这就是 Tensorflow 具有用于加载正确数据的内置 API 的原因。您可以在此处找到有关加载图像的更多信息:https ://www.tensorflow.org/tutorials/load_data/images#load_using_tfdata

来自链接的示例:

  1. 加载训练数据生成器,它返回一个迭代器
train_data_gen = image_generator.flow_from_directory(directory=str(data_dir),
                                                     batch_size=BATCH_SIZE,
                                                     shuffle=True,
                                                     target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                     classes = list(CLASS_NAMES))
  1. 获取训练批次
image_batch, label_batch = next(train_data_gen)

使用这个想法,您可以对自己的代码进行更改


推荐阅读