python - Tensorflow 2.0:模型检查点的自定义指标(平衡准确度得分)不起作用
问题描述
我想实现一个基于平衡准确度得分的模型检查点回调。为此,我实现了以下类:
class BalAccScore(keras.callbacks.Callback):
def __init__(self, validation_data=None):
super(BalAccScore, self).__init__()
self.validation_data = validation_data
def on_train_begin(self, logs={}):
self.balanced_accuracy = []
def on_epoch_end(self, epoch, logs={}):
y_predict = tf.argmax(self.model.predict(self.validation_data[0]), axis=1)
y_true = tf.argmax(self.validation_data[1], axis=1)
balacc = balanced_accuracy_score(y_true, y_predict)
self.balanced_accuracy.append(round(balacc,6))
logs["val_bal_acc"] = balacc
keys = list(logs.keys())
print("\n ------ validation balanced accuracy score: %f ------\n" %balacc)
然后我定义以下回调
balAccScore = BalAccScore(validation_data=(X_2, y_2))
mc = ModelCheckpoint(filepath=callback_path, monitor="val_bal_acc", verbose=1, save_best_only=True, save_freq='epoch')
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['val_bal_acc'])
history = model.fit(X_1, y_1, epochs = 5, batch_size = 512,
callbacks=[balAccScore, mc],
validation_data = (X_2, y_2)
)
然后我得到错误
ValueError: Unknown metric function: val_bal_acc
尽管我在使用例如准确性时发现它在历史记录中,即通过在编译时设置metrics = [“acc”]。在这种情况下,我会收到预期的警告:
WARNING:tensorflow:Can save best model only with val_bal_acc available, skipping.
但除此之外,模型运行完美。不知道为什么它不运行。
解决方案
您应该只删除 compile 中的引号:
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=[val_bal_acc])
或者至少这是它在 R 中的工作方式
推荐阅读
- unit-testing - 无法在厨师食谱中运行 rspec 单元测试“touch_file”
- ios - Swift iOS - 选择新单元格时更改旧单元格的背景颜色
- android - React Native 未在发行版中编译
- android - NoClassDefFoundError:解析失败:Lcom/google/android/gms/common/internal/zzx
- php - WooCommerce 迷你购物车事件触发器
- python - 学习 Python 和 Scrapy
- node.js - HTTPError:在shopify中将结帐标记为已完成时的响应代码422(无法处理的实体)
- python - django 分页不适用于 django 过滤器
- php - 带有字段 concat 的语句 WHERE
- vb.net - 如何将图像从 Form1 分配到 Form2