python - 比较 tf.keras.callbacks.Callback 中回调实例中单个类的精度
问题描述
我需要帮助来设计我的回调,我有以下架构:
def CNN_exctractor(input_img):
l2_loss_lambda = 0.01 # the definintion of l2 regaluraiation
l2 = None if l2_loss_lambda is None else regularizers.l2(l2_loss_lambda)
if l2 is not None:
print('Using L2 regularization - l2_loss_lambda = %.7f' % l2_loss_lambda)
conv1 = Conv2D(filters=64, kernel_size=(3,3), padding="same", activation="relu", kernel_regularizer=l2)(input_img)
conv11 = BatchNormalization()(conv1)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv11)
conv10 = Conv2D(filters=64, kernel_size=(3,3), padding="same", activation="relu", kernel_regularizer=l2)(pool2)
conv110 = BatchNormalization()(conv10)
pool21 = MaxPooling2D(pool_size=(2, 2))(conv110)
conv3 = Conv2D(filters=128, kernel_size=(3,3), padding="same", activation="relu", kernel_regularizer=l2)( pool21)#conv21)
conv31 = BatchNormalization()(conv3)
conv5 = Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu", kernel_regularizer=l2)( conv31)#conv41)
conv51 = BatchNormalization()(conv5)
conv511 = Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="relu", kernel_regularizer=l2)( conv51)#conv41)
conv5111 = BatchNormalization()(conv511)
#pool3 = MaxPooling2D(pool_size=(2, 2))(conv51)
return conv5111
def fc1(enco):
l2_loss_lambda = 0.01
l2 = None if l2_loss_lambda is None else regularizers.l2(l2_loss_lambda)
if l2 is not None:
print('Using L2 regularization - l2_loss_lambda = %.7f' % l2_loss_lambda)
flat = Flatten()(enco)
den = Dense(256, activation='relu',kernel_regularizer=l2)(flat)#(den_n)#(den_n)
den_n= BatchNormalization()(den)
den1 = Dense(128, activation='relu',kernel_regularizer=l2)(den_n)#(den_n)#(den_n)
den1_n= BatchNormalization()(den1)
out = Dense(2, activation='softmax')(den1_n)
return out
如您所见,我在输出端有两个神经元,我使用这个简单的代码进行回调:
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
if((logs.get('val_accuracy')>= 0.92) and (logs.get('accuracy')>= 0.96) and ):
print("\nReached %2.2f%% accuracy, so stopping training!!" %(0.96*100))
self.model.stop_training = True
我正在比较训练和验证的准确性,我想做的不是比较整个验证的准确性,而是比较单个类的精度,例如(如果存在)
logs.get('class_1_precision')>= 0.8
解决方案
您可以将验证数据传递给回调,然后针对特定类对其进行过滤。我不知道您是如何构建验证数据的,但在这里我假设它分为两组(val_x
和val_y
)。在回调中,您将获得包含您需要的类的行(也许过滤您需要val_y
的类的索引,然后获取相同的索引val_x
) - 我已经把这一点留给您了。
from sklearn.metrics import precision_score
class myCallback(tf.keras.callbacks.Callback):
def __init__(self, val_x, val_y):
super(myCallback, self).__init__()
self.val_x = val_x
self.val_y = val_y
def on_epoch_end(self, epoch, logs={}):
# Filter validation data for your required class
val_x_class_1 = #filter self.val_x for your class
val_y_class_1 = #filter self.val_y for your class>
# Get predictions for the filtered val data
class1_scores = self.model.predict(val_x_class_1)
# Get indices of best predictions - you might need to alter this
y_pred_class1 = tf.argmax(class1_scores, axis=1)
y_true_class1 = tf.argmax(val_y_class_1, axis=1)
# Calculate precision
precision_class1 = precision_score(y_true_class1, y_pred_class1)
# Rest of your code
<....>
要将验证数据传递给回调,您需要在 fit 函数中添加如下内容:
cbs = myCallback(val_x,val_y)
model.fit(...., callbacks=[cbs])
推荐阅读
- python-3.x - 如何解决Netlogo中使用python扩展的错误?
- python-3.x - 如何从 PYTHONPATH 上的目录中正确导入模块?
- javascript - 在 React Native 上使用 Ajax
- javascript - module.exports 的含义是什么?为什么叫模块?
- pandas - 如何将 numpy 数组附加到 pandas 数据帧
- ios - SwiftUI - 带有可点击按钮的单元格(在表单内)
- javascript - 如何在鼠标悬停 d3.js 上定位图像
- laravel - 如何计算列
- c++ - VSCode C++ 编译错误:如何链接附加库 (*.lib)?
- python - Python pandas读取文件,写入excel