python - TF 2.0 中使用 Keras 的自定义损失函数
问题描述
TF 1.13 的 Tensorflow 2.0 Alpha 中的自定义损失函数
我正在尝试在TF 2.0中使用此库中的 roc_auc 损失函数。model.compile()
虽然我的实现没有错误,但损失和准确性不会移动。
我首先使用 Google 建议的代码将 1.0 TF 代码转换为 2.0。
然后我从库中导入函数并以下列方式使用:
model.compile(optimizer='adam',
loss=roc_auc_loss,
metrics=['accuracy',acc0, acc1, acc2, acc3, acc4])
Epoch 17/100
100/100 [==============================] - 20s 197ms/step - loss: 469.7043 - accuracy: 0.0000e+00 - acc0: 0.0000e+00 - acc1: 0.0000e+00 - acc2: 0.0000e+00 - acc3: 0.0000e+00 - acc4: 0.0000e+00 - val_loss: 152.2152 - val_accuracy: 0.0000e+00 - val_acc0: 0.0000e+00 - val_acc1: 0.0000e+00 - val_acc2: 0.0000e+00 - val_acc3: 0.0000e+00 - val_acc4: 0.0000e+00
Epoch 18/100
100/100 [==============================] - 20s 198ms/step - loss: 472.0472 - accuracy: 0.0000e+00 - acc0: 0.0000e+00 - acc1: 0.0000e+00 - acc2: 0.0000e+00 - acc3: 0.0000e+00 - acc4: 0.0000e+00 - val_loss: 152.2152 - val_accuracy: 0.0000e+00 - val_acc0: 0.0000e+00 - val_acc1: 0.0000e+00 - val_acc2: 0.0000e+00 - val_acc3: 0.0000e+00 - val_acc4: 0.0000e+00
Epoch 19/100
78/100 [======================>.......] - ETA: 4s - loss: 467.4657 - accuracy: 0.0000e+00 - acc0: 0.0000e+00 - acc1: 0.0000e+00 - acc2: 0.0000e+00 - acc3: 0.0000e+00 - acc4: 0.0000e+00
我想了解 TF 2.0 中的 Keras 有什么问题,它显然不是反向传播。谢谢。
解决方案
@ruben你能分享一个独立的代码来重现这个问题吗?我认为我们需要检查函数定义。您是否在函数定义之上添加了 @tf.function() ?谢谢!
请检查以下示例(来自 TF 网站的简单示例)
!pip install tensorflow==2.0.0-beta1
import tensorflow as tf
from tensorflow import keras
import keras.backend as K
# load mnist data
mnist=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()
x_train,x_test=x_train/255.0,x_test/255.0
# Custom Metric1 (for example)
@tf.function()
def customMetric1(yTrue,yPred):
return tf.reduce_mean(yTrue-yPred)
# Custom Metric2 (for example)
@tf.function()
def customMetric2(yTrue, yPred):
return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred)))
model=tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10,activation='softmax')
])
# Compile the model with custom loss functions
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy', customMetric1, customMetric2])
# Fit and evaluate model
model.fit(x_train,y_train,epochs=5)
model.evaluate(x_test,y_test)
输出
警告:标志解析前的日志记录进入标准错误。W0711 23:57:16.453042 139687207184256 deprecation.py:323] 来自 /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:1250:add_dispatch_support..wrapper(来自 tensorflow.python .ops.array_ops) 已弃用,将在未来版本中删除。更新说明:使用 2.0 中的 tf.where,与 np.where 具有相同的广播规则 Train on 60000 samples
纪元 1/5 60000/60000 [===============================] - 5s 87us/样本 - 损失:0.2983 - 准确度: 0.9133 - customMetric1: 4.3539 - customMetric2: 27.3769 纪元 2/5 60000/60000 [=============================] - 5s 83us/样本 - 损失:0.1456 - 准确度:0.9555 - customMetric1:4.3539 - customMetric2:27.3860
时期 3/5 60000/60000 [===============================] - 5s 82us/样本 - 损失:0.1095 - 准确度:0.9663 - customMetric1:4.3539 - customMetric2:27.3881
时代 4/5 60000/60000 [===============================] - 5s 83us/样本 - 损失:0.0891 - 准确度:0.9717 - customMetric1:4.3539 - customMetric2:27.3893
纪元 5/5 60000/60000 [===============================] - 5s 87us/样本 - 损失:0.0745 - 准确度:0.9765 - customMetric1:4.3539 - customMetric2:27.3901
10000/10000 [===============================] - 0s 46us/样本 - 损失:0.0764 - 准确度:0.9775 - customMetric1 : 4.3429 - customMetric2: 27.3301 [0.07644735965565778, 0.9775, 4.342905, 27.330126]
编辑 1
例如,如果您想customMetric1
用作自定义损失函数,则如下更改几项并运行代码。希望这可以帮助。
# Custom Metric1 (for example)
@tf.function()
def customMetric1(yTrue,yPred,name='CustomLoss1'):
yTrue = tf.dtypes.cast(yTrue, tf.float32)
return tf.reduce_mean(yTrue-yPred)
model.compile(optimizer='adam',loss=customMetric1, metrics=['accuracy', customMetric2])
推荐阅读
- android - 从无名称的嵌套数组 JSON 中获取数据 Android
- azure - Azure Blob 存储中托管的静态网站的 CSP 标头
- python - 在 Juptyer 实验室/笔记本中突出显示检查函数的源代码
- express - 为什么我的应用程序在 5-6 分钟后注销,尽管我已将 max Age 设置为 3 小时?
- reactjs - 每次呈现 null 的组件都会触发 componentDidUpdate 吗?
- laravel - 无法使用带有 imagick 驱动程序的干预库根据需要呈现印地语字体
- html - 如何为 html div 元素创建内部 mediawiki 链接
- vue.js - 为什么 Vue 和我的 DOM Textarea 不同步(不是反应性问题)?
- java - 使用 java Stream 转换一些逻辑
- javascript - 制作和处理自定义电子对话框的最佳方法?