tensorflow - 基于 AUC 的提前停止
问题描述
我对 ML 相当陌生,目前正在使用 tensorflow 和 keras 在 python 中实现一个简单的 3D CNN。我想根据 AUC 进行优化,并且还想在 AUC 分数方面使用提前停止/保存最佳网络。我一直在使用 tensorflow 的 AUC 函数,如下所示,它在训练中效果很好。但是,未保存 hdf5 文件(尽管检查点 save_best_only=True),因此我无法获得评估的最佳权重。
以下是相关的代码行:
model.compile(loss='binary_crossentropy',
optimizer=keras.optimizers.Adam(lr=lr),
metrics=[tf.keras.metrics.AUC()])
model.load_weights(path_weights)
filepath = mypath
check = tf.keras.callbacks.ModelCheckpoint(filepath, monitor=tf.keras.metrics.AUC(), save_best_only=True,
mode='auto')
earlyStopping = tf.keras.callbacks.EarlyStopping(monitor=tf.keras.metrics.AUC(), patience=hyperparams['pat'],mode='auto')
history = model.fit(X_trn, y_trn,
batch_size=bs,
epochs=n_epochs,
verbose=1,
callbacks=[check, earlyStopping],
validation_data=(X_val, y_val),
shuffle=True)
有趣的是,如果我只在早期停止和检查点更改 monitor='val_loss'(而不是 model.compile 中的 'metrics'),hdf5 文件会被保存,但显然在验证损失方面给出了最好的结果。我也尝试过使用 mode='max' 但问题是一样的。我非常感谢您的建议,或任何其他建设性的想法如何解决这个问题。
解决方案
事实证明,即使你添加了一个非关键字指标,当你想要监控它时,你仍然需要使用它的句柄来引用它。在您的情况下,您可以这样做:
auc = tf.keras.metrics.AUC() # instantiate it here to have a shorter handle
model.compile(loss='binary_crossentropy',
optimizer=keras.optimizers.Adam(lr=lr),
metrics=[auc])
...
check = tf.keras.callbacks.ModelCheckpoint(filepath,
monitor='auc', # even use the generated handle for monitoring the training AUC
save_best_only=True,
mode='max') # determine better models according to "max" AUC.
如果您想监控验证 AUC(这更有意义),只需val_
在句柄的开头添加:
check = tf.keras.callbacks.ModelCheckpoint(filepath,
monitor='val_auc', # validation AUC
save_best_only=True,
mode='max')
另一个问题是您的 ModelCheckpoint 正在根据最小AUC 而不是您想要的最大值来保存权重。
这可以通过设置来改变mode='max'
。
做什么mode='auto'
?
此设置实质上检查监视器的参数是否包含'acc'
并将其设置为最大值。在任何其他情况下,它会设置 uses mode='min'
,这就是您的情况。
你可以在这里确认
推荐阅读
- r - R,在使用 order 函数时获得一元运算符的无效参数
- python-3.x - 我在带有 PyCharm 和 Python 3.7.2 的 Windows 10 64 位操作系统上安装了 tensorflow 和 tensorflow_gpu。但是我在使用 tensorflow 时遇到了这些错误
- json - 如何使用 Swift 4.2 正确编码 JSON 响应?
- linux - 如果套接字连接到其主机的IP地址,linux内核是否会优化包传输?
- jquery - 按下键时无法禁用 Windows 键?
- java - 请让我知道如何在 java 的构造函数中将值作为 3 级传递?
- javascript - npm WARN 已弃用 boom@2.10.1:不再维护此版本。请升级到最新版本
- mongodb - 在 MongoDB 中分组并获取每个组的最新日期的所有匹配文档
- android - 如何在 BottomNavigationView 中设置指标?
- python - 如何在 django-graphql-jwt 中更改默认的“用户名”