python - Keras - data generator for a dataset too large to fit in memory
Problem description
I am working with 388 3D MRI images that are too large to fit in the memory available when training a CNN model, so I chose to create a generator that reads batches of images into memory to train on one batch at a time, combined with a custom ImageDataGenerator for 3D images (downloaded from GitHub). I am trying to predict a single test score (range 1-30) from the MRI images. I have the following generator code, and I am not sure whether it is correct:
import numpy as np
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

x = np.asarray(img)
y = np.asarray(scores)

def create_batch(x, y, batch_size):
    x, y = shuffle(x, y)
    x_split, x_val, y_split, y_val = train_test_split(x, y, test_size=.05, shuffle=True)
    x_batch, x_test, y_batch, y_test = train_test_split(x_split, y_split, test_size=.05, shuffle=True)
    x_train, y_train = [], []
    num_batches = len(x_batch)//batch_size
    for i in range(num_batches):
        x_train.append([x_batch[0:batch_size]])
        y_train.append([y_batch[0:batch_size]])
    return x_train, y_train, x_val, y_val, x_batch, y_batch, x_test, y_test, num_batches
epochs = 1
model = build_model(input_size)
x_train, y_train, x_val, y_val, x_batch, y_batch, x_test, y_test, num_batches = create_batch(x, y, batch_size)

train_datagen = customImageDataGenerator(shear_range=0.2,
                                         zoom_range=0.2,
                                         horizontal_flip=True)

val_datagen = customImageDataGenerator()
validation_set = val_datagen.flow(x_val, y_val, batch_size=batch_size, shuffle=False)
def generator(batch_size, epochs):
    for e in range(epochs):
        print('Epoch', e+1)
        batches = 0
        images_fitted = 0
        for i in range(num_batches):
            training_set = train_datagen.flow(x_train[i][0], y_train[i][0], batch_size=batch_size, shuffle=False)
            images_fitted += len(x_train[i][0])
            total_images = len(x_batch)
            print('number of images used: %s/%s' % (images_fitted, total_images))
            history = model.fit_generator(training_set,
                                          steps_per_epoch = 1,
                                          #callbacks = [earlystop],
                                          validation_data = validation_set,
                                          validation_steps = 1)
            model.load_weights('jesse_weights_13layers.h5')
            batches += 1
            yield history
            if batches >= num_batches:
                break
    return model
def train_load_weights():
    history = generator(batch_size, epochs)
    for e in range(epochs):
        for i in range(num_batches):
            print(next(history))
            model.save_weights('jesse_weights_13layers.h5')

for i in range(1):
    print('Run', i+1)
    train_load_weights()
I am not sure whether the generator is built correctly, or whether the model is being trained properly, and I don't know how to check. I would appreciate any suggestions! The code runs; here is part of the training output:
Run 1
Epoch 1
number of images used: 8/349
Epoch 1/1
1/1 [==============================] - 156s 156s/step - loss: 8.0850 - accuracy: 0.0000e+00 - val_loss: 10.8686 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4B4E848>
number of images used: 16/349
Epoch 1/1
1/1 [==============================] - 154s 154s/step - loss: 4.3460 - accuracy: 0.0000e+00 - val_loss: 4.5994 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x0000026899A96708>
number of images used: 24/349
Epoch 1/1
1/1 [==============================] - 148s 148s/step - loss: 4.1174 - accuracy: 0.0000e+00 - val_loss: 4.6038 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F2F488>
number of images used: 32/349
Epoch 1/1
1/1 [==============================] - 151s 151s/step - loss: 4.2788 - accuracy: 0.0000e+00 - val_loss: 4.6029 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F34D08>
number of images used: 40/349
Epoch 1/1
1/1 [==============================] - 152s 152s/step - loss: 3.9328 - accuracy: 0.0000e+00 - val_loss: 4.6057 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F57848>
number of images used: 48/349
Epoch 1/1
1/1 [==============================] - 154s 154s/step - loss: 3.9423 - accuracy: 0.0000e+00 - val_loss: 4.6077 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F4D888>
number of images used: 56/349
Epoch 1/1
1/1 [==============================] - 160s 160s/step - loss: 3.7610 - accuracy: 0.0000e+00 - val_loss: 4.6078 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F3E4C8>
number of images used: 64/349
Solution
I'm not sure how your directory is structured, but if it looks like this:
|---train
|------class1
|---------1.jpg
|---------2.jpg
|------class2
|---------3.jpg
|..........
|---test
|----label
|---------t1.jpg
|---------t2.jpg
Note: there is a subfolder under "test"
Then this is how to use ImageDataGenerator:
generator = ImageDataGenerator(..., validation_split=...)  # for train and valid; augment data here too
train_gen = generator.flow_from_directory("<path_to_train>/train", batch_size=..., target_size=..., subset="training")
valid_gen = generator.flow_from_directory("<path_to_train>/train", batch_size=..., target_size=..., subset="validation")

test_generator = ImageDataGenerator(...)  # no validation split
test_gen = test_generator.flow_from_directory("<path_to_test>/test", class_mode=None, ...)
Then simply call:
model.fit(train_gen, validation_data=valid_gen,...)
model.predict(test_gen)
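If you would rather keep the array-based workflow from the question instead of moving to a directory layout, a plain Python generator that loads one batch at a time is often enough. Below is a minimal NumPy-only sketch; the names are illustrative, and `load_volume` is a hypothetical stand-in for however you read one MRI from disk:

```python
import numpy as np

def batch_generator(paths, scores, batch_size, loader):
    """Yield (x, y) batches forever, holding only one batch in memory at a time."""
    paths = np.asarray(paths)
    scores = np.asarray(scores, dtype="float32")
    n = len(paths)
    while True:                                  # Keras-style generators loop indefinitely
        order = np.random.permutation(n)         # reshuffle every epoch
        for start in range(0, n - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            x = np.stack([loader(p) for p in paths[idx]])  # load just this batch
            yield x, scores[idx]
```

You would then pass it straight to training, e.g. `model.fit(batch_generator(file_paths, scores, 8, load_volume), steps_per_epoch=len(file_paths) // 8, epochs=10)`, instead of repeatedly calling fit and reloading weights per batch as in the question.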