python - keras 的自定义指标 auc 回调问题
问题描述
我正在尝试为 keras 实现一个自定义的 ROC 回调(roc callback)。
我写的回调函数如下。我的目标是让 AUC 达到 0.90。
AUROC 回调:
class rocback(Callback):
    """Keras callback that computes ROC AUC on held-out data after each epoch.

    The score is stored in the epoch ``logs`` dict under the key ``'auc'`` so
    that callbacks placed later in the callback list can read it.

    Args:
        validation_data: A ``(features, labels)`` pair. ``features`` is passed
            to ``model.predict``; ``labels`` are the ground-truth targets.
    """

    def __init__(self, validation_data):
        super().__init__()
        self.validation_data = validation_data

    def on_train_begin(self, logs=None):
        # Nothing to set up; kept so the hook stays explicitly overridden.
        return

    def on_epoch_end(self, epoch, logs=None):
        """Score the current model and publish the result as ``logs['auc']``."""
        # Keep the raw predicted scores: ROC AUC is a ranking metric, and
        # rounding them to 0/1 would collapse the curve to a single threshold.
        probs = self.model.predict(self.validation_data[0])
        # Labels are still rounded, as in the original, so they are hard 0/1.
        y_true = np.round(self.validation_data[1])
        score = roc_auc_score(y_true, probs, average='micro')
        # logs defaults to None (no shared mutable default); Keras normally
        # passes a dict here.
        if logs is not None:
            logs['auc'] = score
接着,我又写了一个回调,用于在分数达到目标值时停止训练。
class scoreTarget(Callback):
    """Stop training once a monitored score reaches a target value.

    Args:
        target: Threshold; training stops when the monitored value is
            greater than or equal to it.
        monitor: Key to look up in the epoch ``logs`` dict. Defaults to
            ``'auc'``, the key this callback originally hard-coded.
    """

    def __init__(self, target, monitor='auc'):
        super().__init__()
        self.target = target
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        """Set ``model.stop_training`` when the monitored score hits the target."""
        current = (logs or {}).get(self.monitor)
        if current is None:
            # The score is only present if an earlier callback or a compiled
            # metric wrote it; warn instead of failing with a KeyError.
            import warnings

            warnings.warn(
                "scoreTarget: '%s' not found in logs; skipping target check"
                % self.monitor
            )
            return
        if current >= self.target:
            self.model.stop_training = True
使用的回调列表如下:
# AUROC callback, evaluated on the (X_test_pooled_output, y_test) pair.
roc_callback = rocback((X_test_pooled_output, y_test))
early_stopping = EarlyStopping(patience=5)
tensorboard = TensorBoard()
reduce_lr = ReduceLROnPlateau(patience=3)
# Stop training once the score reaches 0.90.
target = scoreTarget(0.90)
# roc_callback is listed before target: rocback writes logs['auc'] and
# scoreTarget reads it (assumes Keras runs callbacks in list order -- confirm).
callbacks = [
roc_callback,
early_stopping,
tensorboard,
reduce_lr,
target,
]
我写的分类器函数如下:
class ReviewClassifier(Model):
    """Feed-forward binary classifier.

    Three ReLU Dense layers (64, 32, 16 units), each followed by Dropout(0.2),
    and a final single-unit sigmoid layer.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that their
        # auto-generated names and the summary() listing do not change.
        self.dense_1 = Dense(64, activation='relu')
        self.dense_2 = Dense(32, activation='relu')
        self.dense_3 = Dense(16, activation='relu')
        self.classify = Dense(1, activation='sigmoid')
        self.dropout_1 = Dropout(0.2)
        self.dropout_2 = Dropout(0.2)
        self.dropout_3 = Dropout(0.2)

    def call(self, inputs):
        """Run the Dense -> Dropout stages in order, then the sigmoid head."""
        stages = (
            (self.dense_1, self.dropout_1),
            (self.dense_2, self.dropout_2),
            (self.dense_3, self.dropout_3),
        )
        hidden = inputs
        for dense, dropout in stages:
            hidden = dropout(dense(hidden))
        return self.classify(hidden)
# Instantiate the model and build it for inputs of shape (None, 768)
# (768 is presumably the pooled-embedding width -- confirm against the data).
review_classifier = ReviewClassifier()
review_classifier.build((None, 768))
# Print the layer summary of the built model.
review_classifier.summary()
我写的编译函数是这样的:
# NOTE(review): `rocback` is a Callback subclass, not a metric. Per the
# traceback below, Keras calls each `metrics` entry as fn(y_true, y_pred),
# which here invokes rocback.__init__ with one positional argument too many
# and raises the TypeError. Callbacks belong in fit(callbacks=...) only.
review_classifier.compile(loss='binary_crossentropy',
optimizer='adam',metrics=[rocback])
拟合函数:
# Notebook shell escape: delete everything under ./logs before training.
!rm -rf ./logs/*
# Train for up to 100 epochs in batches of 32, running the callbacks defined
# above and validating on (X_test_pooled_output, y_test).
history = review_classifier.fit(X_train_pooled_output, y_train,
batch_size=32, epochs=100,
callbacks=callbacks,
validation_data=(X_test_pooled_output, y_test))
收到的错误是:
Epoch 1/100
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-54-fd7595a2c88c> in <module>()
1 get_ipython().system('rm -rf ./logs/*')
----> 2 history = review_classifier.fit(X_train_pooled_output, y_train, batch_size=32, epochs=100,callbacks=callbacks, validation_data=(X_test_pooled_output, y_test))
9 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
975 except Exception as e: # pylint:disable=broad-except
976 if hasattr(e, "ag_error_metadata"):
--> 977 raise e.ag_error_metadata.to_exception(e)
978 else:
979 raise
TypeError: in user code:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:805 train_function *
return step_function(self, iterator)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:795 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:788 run_step **
outputs = model.train_step(data)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:758 train_step
self.compiled_metrics.update_state(y, y_pred, sample_weight)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/compile_utils.py:408 update_state
metric_obj.update_state(y_t, y_p, sample_weight=mask)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/metrics_utils.py:90 decorated
update_op = update_state_fn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/metrics.py:177 update_state_fn
return ag_update_state(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/metrics.py:618 update_state **
matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
TypeError: __init__() takes 2 positional arguments but 3 were given
知道我在为 roc 创建的自定义回调中做错了什么吗?
让我知道您还需要什么输入。
解决方案
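报错的直接原因是:在 compile 的 metrics 参数中传入了回调类 rocback,Keras 会把它当作指标函数以 (y_true, y_pred) 调用,于是触发了 __init__() 参数个数不符的 TypeError;回调只应通过 fit 的 callbacks 参数传入。正如我们在评论中提到的,您可以在编译模型时直接使用内置的 tf.keras.metrics.AUC,简单、快速、高效。当然,使用 sklearn.metrics.roc_auc_score 也没有问题,只是在回调中调用它可能会拖慢训练。可以试试这个: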
class ROAUCMetrics(tf.keras.callbacks.Callback):
    """Compute scikit-learn ROC AUC on validation data after every epoch.

    Each epoch's score is printed, appended to ``self.val_aucs`` and stored
    in the epoch ``logs`` dict under ``'val_roc_auc'``. A separate key is used
    so that an ``'auc'`` entry from a compiled metric is never overwritten.

    Args:
        val_data: A ``(features, labels)`` pair used for scoring.
    """

    def __init__(self, val_data):
        super().__init__()
        self.valid_x = val_data[0]
        self.valid_y = val_data[1]
        # Defined up front so the attribute exists even before training.
        self.val_aucs = []

    def on_train_begin(self, logs=None):
        # Start every fit() run with an empty history.
        self.val_aucs = []

    def on_epoch_end(self, epoch, logs=None):
        """Score the model on the validation data and record the result."""
        pred = self.model.predict(self.valid_x)
        val_auc = roc_auc_score(self.valid_y, pred, average='micro')
        # Trailing spaces pad the line so leftover progress-bar text is covered.
        print('\nval-roc-auc: %s' % (str(round(val_auc, 4))), end=100 * ' ' + '\n')
        self.val_aucs.append(val_auc)
        # logs defaults to None (no shared mutable default); Keras normally
        # passes a dict here.
        if logs is not None:
            logs['val_roc_auc'] = val_auc
# scikit-learn AUC, computed by the callback on (x_val, y_val)
roc = ROAUCMetrics(val_data=(x_val, y_val))
# built-in tf.keras AUC, requested as a compiled metric
# (`..` / `...` are placeholders for the remaining compile arguments)
model.compile(.., ..., metrics=["AUC"])
# train with the callback attached
model.fit(x_train, y_train, batch_size=1024,
epochs=..., callbacks=[roc],
validation_data=(x_val, y_val))
# per-epoch AUC values computed with scikit-learn
roc.val_aucs
但是请注意,两者计算 AUC 的方式不同:tf.keras.metrics.AUC 在一组离散阈值上用黎曼和(Riemann sum)近似计算 AUC,而 sklearn 的 roc_auc_score 则直接根据预测分数用梯形法则计算。我在一个示例中测试过,两者的结果通常非常接近,但有时也会有出入。
根据您的第一条评论,您的设置应如下所示:
# Option (1)
# Compile with no metrics; the custom callback computes the AUC instead.
review_classifier.compile(loss='binary_crossentropy',optimizer='adam')
# Option (2)
# Alternatively, compile with additional metrics such as 'accuracy';
# here the built-in AUC is used.
review_classifier.compile(loss='binary_crossentropy',
optimizer='adam',metrics=['AUC'])
# scikit-learn AUC, computed by the callback on (x_val, y_val)
roc = ROAUCMetrics(val_data=(x_val, y_val))
early_stopping = EarlyStopping(patience=5)
tensorboard = TensorBoard()
reduce_lr = ReduceLROnPlateau(patience=3)
# NOTE(review): scoreTarget looks up 'auc' in the epoch logs. Under option (1)
# nothing shown in this snippet provides that key -- confirm before using it.
target = scoreTarget(0.90)
callbacks = [
roc,
early_stopping,
tensorboard,
reduce_lr,
target,
]
# Fit the model.
# NOTE(review): `model`, `x_train` and `x_val` look like placeholders for
# review_classifier and its data -- confirm.
model.fit(x_train, y_train, batch_size=1024,
epochs=..., callbacks=callbacks,
validation_data=(x_val, y_val))
推荐阅读
- date - Windows批处理文件检查时间和日期然后重命名
- node.js - 确定最佳变量以获得最高回报的最快方法
- python-3.x - Python 3 中 bisect_left 中的索引超出范围
- javascript - Javascript post请求函数返回未定义
- mysql - PHP CodeIgniter 和 JQuery AJAX 使用 jquery 追加表行并将所有表数据插入数据库
- spring-boot - SpringBoot RedirectAttributes 未显示在百里香叶中
- r - R中的闰年一些输出是错误的
- asp.net-core - .net 核心 webapp 和控制台应用程序是否有一致的方法来确定数据文件文件夹的路径?
- php - php-将记录插入数据库
- .net - 错误:SMTP 服务器需要安全连接或客户端未通过身份验证。服务器响应为:5.7.0 Authentication Required