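python - How to augment the training set with the tf.keras.utils.Sequence API?
Question
The TensorFlow documentation has the following example. It shows how to create a batch generator that feeds the training set to the model in batches when the set is too large to fit in memory: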
```python
from skimage.io import imread
from skimage.transform import resize
import tensorflow as tf
import numpy as np
import math

# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(tf.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
                         self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
                         self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)
```
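The batching arithmetic in the example above can be checked without any image files. A minimal sketch, assuming in-memory NumPy arrays instead of file paths; `ArraySequence` is a hypothetical name, and the `object` fallback base class is an assumption that only exists so the logic can run where TensorFlow is not installed:

```python
import math

import numpy as np

try:
    from tensorflow.keras.utils import Sequence  # real base class when TensorFlow is available
except ImportError:
    Sequence = object  # fallback (assumption): lets the batching logic run without TensorFlow


class ArraySequence(Sequence):
    """Same slicing as CIFAR10Sequence above, but over arrays already in memory."""

    def __init__(self, x, y, batch_size):
        super().__init__()
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # ceil keeps the last, possibly smaller, batch
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]


x = np.zeros((10, 4, 4, 3), dtype=np.float32)  # 10 tiny dummy images
y = np.arange(10)
seq = ArraySequence(x, y, batch_size=8)
print(len(seq))  # 2
print([seq[i][0].shape[0] for i in range(len(seq))])  # [8, 2]
```

With 10 samples and `batch_size=8`, `math.ceil` yields 2 batches, and the second one holds only the remaining 2 samples.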
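My goal is to further diversify the training set by rotating each image three times by 90º. In each epoch of training, the model would first be fed the "0º training set", then the 90º, 180º and 270º rotated sets, in that order.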
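In terms of a batch index, that feed order is plain integer division: with `n_batches` batches per orientation, index `idx` uses angle number `idx // n_batches` and batch `idx % n_batches`. A minimal sketch of the mapping; the function name and `n_batches` are hypothetical:

```python
ANGLES = [0, 90, 180, 270]


def flat_index_to_angle_and_batch(idx, n_batches):
    """Map a flat index in [0, 4 * n_batches) to (angle, batch index).

    Indices 0 .. n_batches-1 use 0 degrees, the next n_batches use 90, and so on,
    so walking idx in order feeds the 0, 90, 180 and 270 sets one after another.
    """
    angle = ANGLES[idx // n_batches]
    batch_idx = idx % n_batches
    return angle, batch_idx


# Example: 2 batches per orientation -> 8 steps per epoch
schedule = [flat_index_to_angle_and_batch(i, 2) for i in range(4 * 2)]
print(schedule)
# [(0, 0), (0, 1), (90, 0), (90, 1), (180, 0), (180, 1), (270, 0), (270, 1)]
```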
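How can I modify the code above to do this inside the `CIFAR10Sequence()` data generator?
Please don't use `tf.keras.preprocessing.image.ImageDataGenerator()`, so that the answer keeps its generality for similar problems of a different nature.
Note: the idea is to create the new data "on the fly" as it is fed to the model, rather than creating in advance, and storing on disk, a new augmented training set that is larger than the original one and is used later (also in batches) during training.
Thanks in advance.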
Solution
Use a custom `Callback` and hook into `on_epoch_end`: after each epoch, change the rotation angle stored in the data iterator object.
Example (documented inline):
```python
from skimage.io import imread
from skimage.transform import resize, rotate
import numpy as np
from tensorflow import keras
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Activation, Flatten, Dense

# Model architecture (dummy): 15x15 images with 4 channels (e.g. RGBA PNGs)
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(15, 15, 4)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# Data iterator
class CIFAR10Sequence(Sequence):
    def __init__(self, filenames, labels, batch_size):
        super().__init__()
        self.filenames, self.labels = filenames, labels
        self.batch_size = batch_size
        self.angles = [0, 90, 180, 270]
        self.current_angle_idx = 0

    # Step to the next angle, wrapping around after 270
    def change_angle(self):
        self.current_angle_idx += 1
        if self.current_angle_idx >= len(self.angles):
            self.current_angle_idx = 0

    def __len__(self):
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))

    # Read, resize and rotate the images and return one batch
    def __getitem__(self, idx):
        angle = self.angles[self.current_angle_idx]
        print(f"Rotating Angle: {angle}")
        batch_x = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
        # skimage's rotate() takes the angle in degrees, counter-clockwise
        return np.array([
            rotate(resize(imread(filename), (15, 15)), angle)
            for filename in batch_x]), np.array(batch_y)

# Custom callback to hook into the end of each epoch
class CustomCallback(keras.callbacks.Callback):
    def __init__(self, sequence):
        super().__init__()
        self.sequence = sequence

    # After each epoch, switch the rotation used in the next epoch
    def on_epoch_end(self, epoch, logs=None):
        self.sequence.change_angle()

# Create the data reader ("f1.PNG" stands in for real image paths)
sequence = CIFAR10Sequence(["f1.PNG"] * 10, [0, 1] * 5, 8)
# Fit the model with the custom callback hooked in
model.fit(sequence, epochs=10, callbacks=[CustomCallback(sequence)])
```
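`tf.keras.utils.Sequence` also defines its own `on_epoch_end()` method, documented as being called at the end of every epoch, so the angle could be advanced there instead of in a separate callback. A minimal sketch of that variant, assuming the images are already in memory as a NumPy array of shape (N, height, width, channels); `RotatingSequence` is a hypothetical name, `np.rot90` (exact quarter turns, counter-clockwise) is swapped in for `skimage.transform.rotate`, and the `object` fallback base class is an assumption that only lets the logic run without TensorFlow:

```python
import numpy as np

try:
    from tensorflow.keras.utils import Sequence
except ImportError:
    Sequence = object  # fallback (assumption): lets the logic run without TensorFlow


class RotatingSequence(Sequence):
    """Advances the rotation angle in Sequence.on_epoch_end() rather than in a Callback."""

    def __init__(self, images, labels, batch_size):
        super().__init__()
        self.images, self.labels = images, labels
        self.batch_size = batch_size
        self.angles = [0, 90, 180, 270]
        self.current_angle_idx = 0

    def __len__(self):
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, idx):
        k = self.angles[self.current_angle_idx] // 90  # number of quarter turns
        batch = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        # axes=(1, 2) turns every image in the batch in its (height, width) plane
        return np.rot90(self.images[batch], k, axes=(1, 2)), self.labels[batch]

    def on_epoch_end(self):
        # Called by Keras after each epoch: move to the next angle, wrapping after 270
        self.current_angle_idx = (self.current_angle_idx + 1) % len(self.angles)


images = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3, 1)
labels = np.array([0, 1])
seq = RotatingSequence(images, labels, batch_size=2)
before, _ = seq[0]  # first epoch: 0 degrees
seq.on_epoch_end()
after, _ = seq[0]   # second epoch: 90 degrees
```

With this variant, `model.fit(seq, epochs=10)` is expected to advance the angle without the `callbacks` argument; keeping the `print` from the example above is an easy way to confirm this on a given Keras version.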
Output:
```
Rotating Angle: 0
Epoch 1/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 2s 755ms/step - loss: 1.0153 - accuracy: 0.5000
Epoch 2/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 0s 190ms/step - loss: 0.6975 - accuracy: 0.5000
Epoch 3/10
Rotating Angle: 180
Rotating Angle: 180
2/2 [==============================] - 2s 772ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 4/10
Rotating Angle: 270
Rotating Angle: 270
2/2 [==============================] - 0s 197ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 5/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 0s 189ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 6/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 2s 757ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 7/10
Rotating Angle: 180
Rotating Angle: 180
2/2 [==============================] - 2s 757ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 8/10
Rotating Angle: 270
Rotating Angle: 270
2/2 [==============================] - 2s 761ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 9/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 1s 744ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 10/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 0s 192ms/step - loss: 0.6931 - accuracy: 0.5000
<tensorflow.python.keras.callbacks.History at 0x7fcbdf8bcdd8>
```
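The callback approach rotates the whole training set once per Keras epoch, so it takes four epochs to see all four orientations. If a single epoch should instead feed the 0º, 90º, 180º and 270º sets one after another, as the question describes, one option is to make the `Sequence` four times as long and decode the angle from the index. A minimal sketch, assuming in-memory NumPy images of shape (N, height, width, channels); `AllRotationsSequence` is a hypothetical name, `np.rot90` is swapped in for `skimage.transform.rotate`, and the `object` fallback base class only lets the indexing run without TensorFlow:

```python
import numpy as np

try:
    from tensorflow.keras.utils import Sequence
except ImportError:
    Sequence = object  # fallback (assumption): lets the indexing run without TensorFlow


class AllRotationsSequence(Sequence):
    """One epoch = the 0, 90, 180 and 270 degree versions of the data set, in that order."""

    angles = (0, 90, 180, 270)

    def __init__(self, images, labels, batch_size):
        super().__init__()
        self.images, self.labels = images, labels
        self.batch_size = batch_size
        self.n_batches = int(np.ceil(len(images) / batch_size))  # batches per orientation

    def __len__(self):
        # each original batch is served once per angle
        return self.n_batches * len(self.angles)

    def __getitem__(self, idx):
        angle = self.angles[idx // self.n_batches]  # which orientation block idx falls in
        b = idx % self.n_batches                    # which batch inside that block
        batch = slice(b * self.batch_size, (b + 1) * self.batch_size)
        # rotate counter-clockwise by angle // 90 quarter turns in the (height, width) plane
        return np.rot90(self.images[batch], angle // 90, axes=(1, 2)), self.labels[batch]


images = np.arange(3 * 2 * 2, dtype=np.float32).reshape(3, 2, 2, 1)
labels = np.array([0, 1, 2])
seq = AllRotationsSequence(images, labels, batch_size=2)
print(len(seq))  # 8: 2 batches x 4 angles
```

No callback is needed here, since the rotation depends only on the index. In tf.keras 2.x, `model.fit` can shuffle the order of a `Sequence`'s batches when `shuffle=True` (the default), which would interleave the four orientations; passing `shuffle=False` keeps them in order.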