tensorflow - How to call a method as a custom callback in Keras?
Problem description
I need to run the following method after every 5K iterations.
def evaluation_matrix(path_true, path_pred):
    print(path_true, "\n", path_pred)
    # read_from_folder and scikit_metrix are the asker's own helper functions
    true_list_new, pred_list_new = read_from_folder(path_true=path_true, path_pred=path_pred)
    try:
        scikit_metrix(true_list_new=true_list_new, pred_list_new=pred_list_new)
    except Exception as e:
        print("An exception occurred:", e)
I want to use it as a callback in model.fit_generator. How can I do that, i.e. pass the arguments and run it at a 5K-iteration interval?
history = model.fit_generator(generator = myGene, steps_per_epoch=steps_per_epoch, epochs=epoch, verbose = 1, callbacks=[],shuffle=True)
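Solution
Custom callbacks are a powerful tool for customizing the behavior of a Keras model during training, evaluation, or inference.
Below is an example that computes the gradients at the end of each epoch. In the same way, you can override many other built-in callback methods to customize further. You can find more information here - https://www.tensorflow.org/guide/keras/custom_callback
Note: I am using TensorFlow 1.15.0 for the gradient example below.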
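Applied to the question, one of those built-in methods, on_batch_end, is enough to run evaluation_matrix every 5,000 training iterations. Below is a minimal sketch, assuming tf.keras; PeriodicFunctionCallback, every_n_batches and func_kwargs are hypothetical names chosen for illustration, not part of the Keras API. The callback keeps its own step counter because the batch index Keras passes to on_batch_end restarts at 0 every epoch, so a 5K interval can span epoch boundaries. It overrides on_batch_end rather than on_train_batch_end because in tf.keras the default on_train_batch_end forwards to on_batch_end, and older standalone Keras versions only call on_batch_end.

```python
import tensorflow as tf


class PeriodicFunctionCallback(tf.keras.callbacks.Callback):
    """Hypothetical helper (not a Keras API): call func(**func_kwargs) every N training batches."""

    def __init__(self, func, every_n_batches=5000, func_kwargs=None):
        super().__init__()
        self.func = func
        self.every_n_batches = every_n_batches
        self.func_kwargs = func_kwargs or {}
        # Global step counter: the `batch` argument Keras passes resets to 0 each epoch
        self.seen_batches = 0

    def on_batch_end(self, batch, logs=None):
        self.seen_batches += 1
        if self.seen_batches % self.every_n_batches == 0:
            self.func(**self.func_kwargs)
```

For the question's setup, the callback would be passed as callbacks=[PeriodicFunctionCallback(evaluation_matrix, every_n_batches=5000, func_kwargs={"path_true": path_true, "path_pred": path_pred})] in the fit_generator call, with path_true and path_pred being whatever folders are already in use. In TF 2.x, fit_generator is deprecated and model.fit accepts the generator directly with the same callbacks argument.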
# (1) Import dependencies
import tensorflow as tf
import keras
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.layers.normalization import BatchNormalization
import numpy as np
np.random.seed(1000)
# (2) Get Data
import tflearn.datasets.oxflower17 as oxflower17
x, y = oxflower17.load_data(one_hot=True)
# (3) Create a sequential model
model = Sequential()
# 1st Convolutional Layer
model.add(Conv2D(filters=96, input_shape=(224,224,3), kernel_size=(11,11), strides=(4,4), padding='valid'))
model.add(Activation('relu'))
# Pooling
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))
# Batch Normalisation before passing it to the next layer
model.add(BatchNormalization())
# 2nd Convolutional Layer
model.add(Conv2D(filters=256, kernel_size=(11,11), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Pooling
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))
# Batch Normalisation
model.add(BatchNormalization())
# 3rd Convolutional Layer
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Batch Normalisation
model.add(BatchNormalization())
# 4th Convolutional Layer
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Batch Normalisation
model.add(BatchNormalization())
# 5th Convolutional Layer
model.add(Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Pooling
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))
# Batch Normalisation
model.add(BatchNormalization())
# Passing it to a dense layer
model.add(Flatten())
# 1st Dense Layer
model.add(Dense(4096))
model.add(Activation('relu'))
# Add Dropout to prevent overfitting
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())
# 2nd Dense Layer
model.add(Dense(4096))
model.add(Activation('relu'))
# Add Dropout
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())
# 3rd Dense Layer
model.add(Dense(1000))
model.add(Activation('relu'))
# Add Dropout
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())
# Output Layer
model.add(Dense(17))
model.add(Activation('softmax'))
model.summary()
# (4) Compile
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
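# (5) Define the required callback
epoch_gradient = []
def get_gradient_func(model):
    # Symbolic gradients of the total loss w.r.t. all trainable weights (TF 1.x graph mode)
    grads = K.gradients(model.total_loss, model.trainable_weights)
    # Placeholders created by compile(); these are private Keras attributes
    inputs = model._feed_inputs + model._feed_targets + model._feed_sample_weights
    func = K.function(inputs, grads)
    return func
class GradientCalcCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Build the gradient function once instead of adding new graph ops every epoch
        self.get_gradient = get_gradient_func(self.model)
    def on_epoch_end(self, epoch, logs=None):
        # Gradients over the whole dataset, with every sample weighted 1
        grads = self.get_gradient([x, y, np.ones(len(y))])
        epoch_gradient.append(grads)
# (6) Train with the callback attached
model.fit(x, y, batch_size=64, epochs=4, verbose=1, validation_split=0.2, shuffle=True, callbacks=[GradientCalcCallback()])
# (7) Convert to a 2-dimensional array of (epoch, gradients) type
# dtype=object because the per-weight gradient arrays have different shapes
gradient = np.asarray(epoch_gradient, dtype=object)
print("Total number of epochs run:", len(epoch_gradient))
print("Gradient Array has the shape:", gradient.shape)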
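The example above relies on TF 1.x graph-mode machinery (K.gradients, model.total_loss and the private _feed_* placeholders); with eager execution in TF 2.x, K.gradients is not supported. A rough TF 2.x counterpart is sketched below, assuming tf.keras and a loss function passed in explicitly (loss_fn is a hypothetical parameter that should match the loss given to compile). It recomputes the gradients with tf.GradientTape at the end of each epoch and calls the model with training=False, which, like the default learning phase of K.function, keeps dropout off and makes batch normalization use its moving statistics.

```python
import numpy as np
import tensorflow as tf


class GradientCalcCallback(tf.keras.callbacks.Callback):
    """Sketch for TF 2.x: store the loss gradient w.r.t. every trainable weight after each epoch.

    loss_fn is an assumed parameter; it should match the loss used in model.compile().
    """

    def __init__(self, x, y, loss_fn):
        super().__init__()
        self.x = np.asarray(x, dtype="float32")
        self.y = np.asarray(y, dtype="float32")
        self.loss_fn = loss_fn
        self.epoch_gradients = []  # one list of per-weight gradient arrays per epoch

    def on_epoch_end(self, epoch, logs=None):
        with tf.GradientTape() as tape:
            # training=False: dropout off, batch norm uses its moving statistics
            preds = self.model(self.x, training=False)
            loss = self.loss_fn(self.y, preds)
        grads = tape.gradient(loss, self.model.trainable_weights)
        self.epoch_gradients.append([g.numpy() for g in grads])
```

With the model above rebuilt in tf.keras, this would be used as cb = GradientCalcCallback(x, y, tf.keras.losses.CategoricalCrossentropy()) followed by model.fit(..., callbacks=[cb]). Passing all of x in one forward call mirrors the original example, but larger datasets may need the gradient computed over batches instead.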