tensorflow - 在 Keras 中使用 Eban 等人的召回损失精度
问题描述
我想使用 keras测试非标准损失,例如https://arxiv.org/abs/1608.04802中描述的precision_at_recall_loss。
这些损失在此处实现:httpsloss_layers.py
: //github.com/tensorflow/models/tree/archive/research/global_objectivesutil.py
以下代码是使用 MNIST 数据集的演示。
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import loss_layers
import util
def precision_recall_auc_loss(y_true, y_pred):
y_true = keras.backend.reshape(y_true, (batch_size, 1))
y_pred = keras.backend.reshape(y_pred, (batch_size, 1))
util.get_num_labels = lambda labels : 1
return loss_layers.precision_recall_auc_loss(y_true, y_pred)[0]
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
input_shape = x_train.shape[1:]
num_classes = 10
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
model = keras.Sequential([
layers.Conv2D(32,kernel_size=(3,3),activation='relu',input_shape=input_shape), \
layers.MaxPooling2D(2,2), \
layers.Flatten(), \
layers.Dropout(0.25), \
layers.Dense(num_classes, activation="softmax")
])
model.summary()
batch_size = 30
epochs = 10
target_recall = 0.9
model.compile(loss=precision_recall_auc_loss,
optimizer=keras.optimizers.Adam(lr=0.001))
model.fit(x_train, y_train, batch_size=batch_size, \
epochs=epochs, validation_split=0.15)
模型编译并开始拟合。但是,我收到以下错误:
Train on 51000 samples, validate on 9000 samples
Epoch 1/10
FailedPreconditionError: Attempting to use uninitialized value precision_at_recall_1/lambdas
[[{{node precision_at_recall_1/lambdas/read}}]]
解决方案
推荐阅读
- ios - 位置标记不显示 Swift 4
- java - 如何停止所有 VBS 文字转语音?
- swift - TableViewCell 元素中的 UIMenuController
- android - 在 android studio 中编译失败
- linux - Unix 用户级别如何影响我对 docker 映像的可见性?
- android - 计算到特定时区午夜的秒数
- javascript - 在 try/catch 块中成功异步请求后执行语句
- c++ - 交错插入排序功能无法正确排序
- android - Expo with Android Studio Emulator 错误:“Error running adb: This computer is not authorized to debug the device”
- php - 文件夹中的 HTML 图像不再出现 (PHP)