python - Keras 内存分配问题
问题描述
我一直在尝试构建一个深度学习模型,使用 Keras 作为 API,搭配 TensorFlow 2.2 和 CUDA 10.2.89。我的数据集“相对较大”(27500 张 400x400 图像),因此我尝试使用 keras.utils.Sequence
(此处)和tf.keras.preprocessing.image.ImageDataGenerator
(此处)将这些图像批量放入内存中。然而,到目前为止,这些都不起作用,我一定没有正确使用它,但我看不出我的代码有什么问题。
使用 keras.utils.Sequence:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from tensorflow import keras
class BatchGenerator(keras.utils.Sequence):
    """Serve aligned ``(inputs, targets)`` batches from two sliceable collections.

    Args:
        inputs: Collection of samples; must support ``inputs[a:b]`` slicing.
        targets: Collection of labels aligned with ``inputs``; must support
            ``len()`` and slicing.
        batch_size: Maximum number of samples per batch.
    """

    def __init__(self, inputs, targets, batch_size=32):
        super().__init__()
        self.inputs = inputs
        self.targets = targets
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches per epoch, counting a final partial batch."""
        # Ceiling division: plain floor division would silently drop the last
        # len(targets) % batch_size samples (and yield zero batches when the
        # dataset is smaller than one batch).
        return (len(self.targets) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return batch number ``idx`` as an ``(x, y)`` tuple.

        The last batch may hold fewer than ``batch_size`` samples, because
        slicing past the end of the data simply stops at the end.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.inputs[start:stop], self.targets[start:stop]
# Load images and labels with normalization and one hot encoding
# images: np.array, (27512, 480, 480, 3); labels: np.array, (27512, 480, 480, 9); nb_class = 9
images, labels, nb_class = data_loader('../Data/Images', '../Data/labels')

# Creation of the 0.2 validation split: the first fifth of the samples is held
# out for validation, the remainder is used for training.
# Slice objects are used instead of index lists on purpose: indexing a NumPy
# array with a list makes a full copy of the selected data, whereas a basic
# slice only creates a view on the existing buffer.
n_samples = images.shape[0]
train = slice(n_samples // 5, n_samples)
test = slice(0, n_samples // 5)

# Batch generation
# NOTE(review): BatchGenerator.__init__ as shown in this file accepts only
# (inputs, targets, batch_size); passing data_aug as a fourth argument raises
# TypeError unless the complete class takes it - confirm against the full code.
train_gen = BatchGenerator(images[train], labels[train], batch_size, data_aug)
valid_gen = BatchGenerator(images[test], labels[test], batch_size, data_aug)

# Model fitting
history = model.fit(
    train_gen,
    epochs=nb_epoch,
    verbose=1,
    validation_data=valid_gen,
    shuffle=False,
    callbacks=callbacks,
)
使用 ImageDataGenerator(链接)和 flow(链接):
from keras.preprocessing.image import ImageDataGenerator
# Load the whole dataset; per the loader the images come back normalized and
# the labels one hot encoded.
# images: np.array, (27512, 480, 480, 3); labels: np.array, (27512, 480, 480, 9); nb_class = 9
images, labels, nb_class = data_loader('../Data/Images', '../Data/labels')

# Augmentation pipeline; validation_split reserves a 0.2 fraction of the
# samples for the 'validation' subset.
datagen = ImageDataGenerator(
    horizontal_flip=True,
    vertical_flip=True,
    samplewise_center=True,
    samplewise_std_normalization=True,
    preprocessing_function=augmentations_color,
    validation_split=0.2,
)

# Both iterators share the same batching options and differ only in subset.
flow_options = {'batch_size': 16, 'shuffle': True}
train_gen = datagen.flow(images, labels, subset='training', **flow_options)
valid_gen = datagen.flow(images, labels, subset='validation', **flow_options)

# Train the model on the two iterators.
history = model.fit(
    train_gen,
    validation_data=valid_gen,
    epochs=nb_epoch,
    callbacks=callbacks,
    shuffle=False,
    verbose=1,
)
在这两种情况下,代码都在批次生成(BatchGenerator)阶段崩溃,始终无法执行到模型拟合部分,并报错:
Traceback (most recent call last):
File "main.py", line 190, in <module>
conf_file=train_set.cfg_path)
File "main.py", line 92, in main
shuffle=True, subset='training')
File "/hpc_htom/kjam268/Virtual_ENV/HistoTAG/lib/python3.6/site-packages/keras_preprocessing/image/image_data_generator.py", line 434, in flow
dtype=self.dtype
File "/hpc_htom/kjam268/Virtual_ENV/HistoTAG/lib/python3.6/site-packages/keras_preprocessing/image/numpy_array_iterator.py", line 103, in __init__
np.unique(y[split_idx:]))):
File "<__array_function__ internals>", line 6, in unique
File "/hpc_htom/kjam268/Virtual_ENV/HistoTAG/lib/python3.6/site-packages/numpy/lib/arraysetops.py", line 261, in unique
ret = _unique1d(ar, return_index, return_inverse, return_counts)
File "/hpc_htom/kjam268/Virtual_ENV/HistoTAG/lib/python3.6/site-packages/numpy/lib/arraysetops.py", line 314, in _unique1d
ar = np.asanyarray(ar).flatten()
numpy.core._exceptions.MemoryError: Unable to allocate 189. GiB for an array with shape (50711040000,) and data type float32
我已将代码精简到最重要的部分,但如果需要,我可以发布所有代码。这两种方案我都试过了,并使用了不同的批次大小(最小到 1),但没有任何改善。
关于我可以尝试训练我的模型的任何想法?
谢谢
解决方案
您可以按照这个官方说明使用类似 U-Net 的架构进行图像分割。这是与您的案例最相关的编码部分。
生成器
class BatchGenerator(keras.utils.Sequence):
    """Keras Sequence that builds NumPy batches from lists of image paths."""

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        # Only complete batches are counted (floor division).
        n_targets = len(self.target_img_paths)
        return n_targets // self.batch_size

    def __getitem__(self, idx):
        """Build batch number ``idx`` and return it as an ``(inputs, targets)`` tuple."""
        start = idx * self.batch_size
        window = slice(start, start + self.batch_size)
        batch_shape = (self.batch_size,) + self.img_size

        # Input images: float32 array with 3 channels per pixel.
        inputs = np.zeros(batch_shape + (3,), dtype="float32")
        for row, img_path in enumerate(self.input_img_paths[window]):
            inputs[row] = load_img(img_path, target_size=self.img_size)

        # Target masks: uint8 array with a single channel per pixel.
        targets = np.zeros(batch_shape + (1,), dtype="uint8")
        for row, mask_path in enumerate(self.target_img_paths[window]):
            mask = load_img(
                mask_path, target_size=self.img_size, color_mode="grayscale"
            )
            targets[row] = np.expand_dims(mask, 2)
            # Ground truth labels are 1, 2, 3; shift them down to 0, 1, 2.
            targets[row] -= 1

        return inputs, targets
数据集
input_dir = "images/"
target_dir = "annotations/trimaps/"
batch_size = 32


def _sorted_paths(directory, keep):
    """Return the sorted joined paths of the entries of ``directory`` accepted by ``keep``."""
    return sorted(
        os.path.join(directory, entry)
        for entry in os.listdir(directory)
        if keep(entry)
    )


# Inputs are the .jpg files; targets are the .png files whose name does not
# start with a dot.
input_img_paths = _sorted_paths(input_dir, lambda entry: entry.endswith(".jpg"))
target_img_paths = _sorted_paths(
    target_dir,
    lambda entry: entry.endswith(".png") and not entry.startswith("."),
)

print("Number of samples:", len(input_img_paths))

# Show the first ten input/target pairs side by side.
first_pairs = zip(input_img_paths[:10], target_img_paths[:10])
for input_path, target_path in first_pairs:
    print(input_path, "|", target_path)
数据生成器
# Hold out the last `val_samples` shuffled paths for validation.
val_samples = 1000

# Each list is shuffled in place by its own generator seeded with the same
# value, which keeps input/target pairs aligned (assuming both lists have the
# same length).
for paths in (input_img_paths, target_img_paths):
    random.Random(1337).shuffle(paths)

train_input_img_paths, val_input_img_paths = (
    input_img_paths[:-val_samples],
    input_img_paths[-val_samples:],
)
train_target_img_paths, val_target_img_paths = (
    target_img_paths[:-val_samples],
    target_img_paths[-val_samples:],
)

# One data Sequence per split.
train_gen = BatchGenerator(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=train_input_img_paths,
    target_img_paths=train_target_img_paths,
)
val_gen = BatchGenerator(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=val_input_img_paths,
    target_img_paths=val_target_img_paths,
)
推荐阅读
- assembly - 指令 ROL.L d0, d0 的目的是什么?
- prolog - 书本示例中的 Prolog 存在错误
- python - 如何在 Python 中更改 Pandas 数据框的结构?
- docker - ApolloServer:“无法连接到 websocket 端点 ws://localhost:4000/subscriptions。请检查端点 url 是否正确。”
- reactjs - webpack-dev-server 无法从 Visual Studio 加载 0.chunk.js
- rdf - 如何使用 OWL(Turtle 语法)?
- flutter - 如何在 Flutter 应用中的所有屏幕上添加水印
- heroku - 如何在heroku上使用dlib预测器部署django rest API?
- parsing - 备注:如何在 MDAST 中解析 HTML 标签及其内容
- python - 具有多个输出 Python 的拟合函数