python - 如何定义批处理生成器?
问题描述
我有一个包含大约一百万张图片的目录。我想创建一个batch_generator
这样我就可以训练我的 CNN,因为我不能一次将所有这些图像保存在内存中。
所以,我写了一个生成器函数来做到这一点:
def batch_generator(image_paths, batch_size, isTraining):
while True:
batch_imgs = []
batch_labels = []
type_dir = 'train' if isTraining else 'test'
for i in range(len(image_paths)):
print(i)
print(os.path.join(data_dir_base, type_dir, image_paths[i]))
img = cv2.imread(os.path.join(data_dir_base, type_dir, image_paths[i]), 0)
img = np.divide(img, 255)
img = img.reshape(28, 28, 1)
batch_imgs.append(img)
label = image_paths[i].split('_')[1].split('.')[0]
batch_labels.append(label)
if len(batch_imgs) == batch_size:
yield (np.asarray(batch_imgs), np.asarray(batch_labels))
batch_imgs = []
if batch_imgs:
yield batch_imgs
当我调用此语句时:
index = next(batch_generator(train_dataset, 10, True))
它正在打印相同的索引值和路径,因此,它在每次调用next()
. 我该如何解决?
我用这个问题作为代码的参考:how to split an iterable in constant-size chunks
解决方案
# batch generator
def get_batches(dataset, batch_size):
X, Y = dataset
n_samples = X.shape[0]
# Shuffle at the start of epoch
indices = np.arange(n_samples)
np.random.shuffle(indices)
for start in range(0, n_samples, batch_size):
end = min(start + batch_size, n_samples)
batch_idx = indices[start:end]
yield X[batch_idx], Y[batch_idx]
推荐阅读
- ruby-on-rails - Jekyll ERROR Errno::ECONNRESET: Connection reset by peer @ io_fillbuf - fd:17
- python - 在 RelativeLayout 中嵌入 GridLayout
- karate - url中的空手道框架符号编码
- java - Java中的可选参数
- javascript - 为什么输出是 1 2 3 13 ,背后的逻辑是什么?
- python - 如何在使用 python selenium 键入文本后从搜索建议中获取价值?
- netflix-eureka - SpringCloud Eureka 心跳与健康检查
- javascript - Uncaught SyntaxError: missing ) 在尝试添加函数以对数据进行排序时在参数列表之后
- python - Python/Plotly:如何从 OLS 行中提取“m”和“b”?
- r - 基于 R 中的现有列创建新列