How do I define a batch generator?

Problem Description

I have a directory containing roughly a million images. I want to create a batch_generator so that I can train my CNN, because I can't hold all of these images in memory at once.

So I wrote a generator function to do this:

def batch_generator(image_paths, batch_size, isTraining):
    while True:
        batch_imgs = []
        batch_labels = []
        
        type_dir = 'train' if isTraining else 'test'
        
        for i in range(len(image_paths)):
            print(i)
            print(os.path.join(data_dir_base, type_dir, image_paths[i]))
            img = cv2.imread(os.path.join(data_dir_base, type_dir, image_paths[i]), 0)
            img = np.divide(img, 255)
            img = img.reshape(28, 28, 1)
            batch_imgs.append(img)
            label = image_paths[i].split('_')[1].split('.')[0]
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                yield (np.asarray(batch_imgs), np.asarray(batch_labels))
                batch_imgs = []
        if batch_imgs:
            yield batch_imgs

When I call this statement:

index = next(batch_generator(train_dataset, 10, True))

It keeps printing the same index values and paths, meaning the iteration restarts from the beginning on every call to next(). How can I fix this?

I used this question as a reference for the code: how to split an iterable in constant-size chunks

Tags: python, python-3.x, generator

Solution
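
The repeated indices come from the call site, not the generator body: index = next(batch_generator(train_dataset, 10, True)) builds a brand-new generator object on every call, so iteration always restarts at i = 0. Create the generator once and keep advancing that same object:

gen = batch_generator(train_dataset, 10, True)  # build the generator once
batch_imgs, batch_labels = next(gen)  # first batch
batch_imgs, batch_labels = next(gen)  # second batch; iteration resumes where it left off

Separately, if the whole dataset fits in memory as a pair of NumPy arrays, a simpler pattern is to shuffle an index array at the start of each pass and yield fixed-size slices of it, as in the generator below: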


import numpy as np

# batch generator: shuffles once, then yields fixed-size mini-batches
def get_batches(dataset, batch_size):
    X, Y = dataset
    n_samples = X.shape[0]

    # Shuffle at the start of the epoch
    indices = np.arange(n_samples)
    np.random.shuffle(indices)

    for start in range(0, n_samples, batch_size):
        # the last batch may be smaller than batch_size
        end = min(start + batch_size, n_samples)

        batch_idx = indices[start:end]

        yield X[batch_idx], Y[batch_idx]
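
A minimal usage sketch, assuming the dataset has already been loaded into NumPy arrays; the names X_train, y_train, and num_epochs are placeholders, not part of the original answer. Since get_batches is exhausted after a single pass over the data, a fresh generator is created at the top of each epoch, which also reshuffles the indices:

import numpy as np

# stand-in data: 100 grayscale 28x28 images with integer labels
X_train = np.random.rand(100, 28, 28, 1)
y_train = np.random.randint(0, 10, size=100)

num_epochs = 3
for epoch in range(num_epochs):
    # a new generator per epoch means a new shuffle per epoch
    for X_batch, Y_batch in get_batches((X_train, y_train), batch_size=10):
        pass  # run one training step on (X_batch, Y_batch) here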
