python - 使用 TensorFlow-Keras API 进行数据增强
问题描述
以下代码允许在每个 epoch 结束时将训练集的图像旋转 90º。
from skimage.io import imread
from skimage.transform import resize, rotate
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.utils import Sequence
from keras.models import Sequential
from keras.layers import Conv2D, Activation, Flatten, Dense
# Model architecture (dummy)
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(15, 15, 4)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
# Data iterator
class CIFAR10Sequence(Sequence):
def __init__(self, filenames, labels, batch_size):
self.filenames, self.labels = filenames, labels
self.batch_size = batch_size
self.angles = [0,90,180,270]
self.current_angle_idx = 0
# Method to loop throught the available angles
def change_angle(self):
self.current_angle_idx += 1
if self.current_angle_idx >= len(self.angles):
self.current_angle_idx = 0
def __len__(self):
return int(np.ceil(len(self.filenames) / float(self.batch_size)))
# read, resize and rotate the image and return a batch of images
def __getitem__(self, idx):
angle = self.angles[self.current_angle_idx]
print (f"Rotating Angle: {angle}")
batch_x = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
return np.array([
rotate(resize(imread(filename), (15, 15)), angle)
for filename in batch_x]), np.array(batch_y)
# Custom call back to hook into on epoch end
class CustomCallback(keras.callbacks.Callback):
def __init__(self, sequence):
self.sequence = sequence
# after end of each epoch change the rotation for next epoch
def on_epoch_end(self, epoch, logs=None):
self.sequence.change_angle()
# Create data reader
sequence = CIFAR10Sequence(["f1.PNG"]*10, [0, 1]*5, 8)
# fit the model and hook in the custom call back
model.fit(sequence, epochs=10, callbacks=[CustomCallback(sequence)])
如何修改代码以便在每个时期都发生图像的旋转?
期望的输出:
Epoch 1/10
Rotating Angle: 0
Rotating Angle: 90
Rotating Angle: 180
Rotating Angle: 270
Epoch 2/10
Rotating Angle: 0
Rotating Angle: 90
Rotating Angle: 180
Rotating Angle: 270
(...)
Epoch 10/10
Rotating Angle: 0
Rotating Angle: 90
Rotating Angle: 180
Rotating Angle: 270
换句话说,我如何编写一个在一个时期的“结束”运行的回调,它改变了角度值并在同一个时期继续训练(而不改变到下一个时期)?
提前致谢
注意:代码学分来自“mujjiga”。
解决方案
由于您有一个自定义序列生成器,您可以创建一个在纪元开始或结束时运行的函数。那是您可以放置代码来修改图像的地方。文档在[这里][1]
Epoch-level methods (training only)
on_epoch_begin(self, epoch, logs=None)
Called at the beginning of an epoch during training.
on_epoch_end(self, epoch, logs=None)
Called at the end of an epoch during training.
[1]: https://keras.io/guides/writing_your_own_callbacks/
推荐阅读
- swift - XCode/Swift:将 Bash 输出解析为 Textwindow 的函数
- node.js - 服务器启动时的 Mongoose ValidationError
- angular - Angular CLI 生成组件问题
- python - 将多个参数传递给异步 - Python 多处理
- module - 是否可以 require() 使用 luaL_loadstring() 加载的脚本?
- python - 在 python 请求中为 mailjet 添加变量
- javascript - 通过键盘输入 2 个数字并使用 Math 显示最大/最小和功率等级?
- javascript - vue-cli 3.0 - 输出库代码拆分时未定义 jsonpArray.push
- ruby-on-rails - ruby 的 link_to 方法不适用于 shopify 嵌入式应用程序
- sql-server - 删除几个表的外键上的 ON DELETE CASCADE