tensorflow - 如果验证准确性比上一个时期有所提高,如何编写自定义回调以在每个时期保存模型
问题描述
以下是我编写的自定义回调函数,但它不起作用:
class bestval(tf.keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.history={'loss': [],'acc': [],'val_loss': [],'val_acc': []}
def on_epoch_end(self, epoch, logs={}):
#appending val_acc in history
if logs.get('val_acc', -1) != -1:
self.history['val_acc'].append(logs.get('val_acc'))
# Trying to compare current epoch val_acc with all the values in self.history['val_acc']
if logs.get('val_acc')> [i for i in self.history['val_acc']]:
filepath="model_save/weights-{epoch:02d}-{val_acc:.4f}.hdf5"
# Saving the model using TF built-in callback
checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(filepath=filepath,
monitor='val_acc', verbose=1, mode='auto')
bestobj= bestval()
拟合模型:
model.fit(xtr,ytr, epochs=4, validation_data=(xte,yte), batch_size=128, callbacks=[bestobj])
当我运行上面我得到以下错误:
ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()
我知道我在做一些愚蠢的事情,但我不知道如何解决。任何帮助,将不胜感激。
解决方案
我猜错误在下一行,您正在尝试将值与列表进行比较。
if logs.get('val_acc')> [i for i in self.history['val_acc']]:
尝试,
for i in self.history['val_acc']:
if logs.get('val_acc')>i:
#your code
推荐阅读
- php - 如何在 MediaWiki 中创建新功能?- 不是扩展
- python - 如何合并两个 Word2Vec 文件
- python - scikit-learn 中 kmeans 的自定义标准
- css - 如何将关键帧动画保存为 APNG 或 GIF?
- sql - 获取所有超过 25 个字符的记录 (LENGTH(hcp.phone_number) > 25)
- python - 在我的 virtualenv 中安装期货包后出现语法错误
- excel - 如何修复以下 VBA 代码上的错误 400
- python - 使用 Python 在 Excel 中重命名工作表;引用工作表名称的单元格值
- bash - 语法错误:“(”使用 bash 命令创建 openssl 证书时出现意外错误
- apache-spark - Spark-redis:数据帧写入时间太慢