python - 自定义 keras 指标返回 nan
问题描述
我想在分类问题中进一步将输出类分成更少的桶。我有 4 个输出类(即 0、1、2、3)。但在训练期间,我还想跟踪 2 个类的准确性:
- 将 0 和 1 视为 0 类
- 将 2 和 3 视为 1 类
为此,我创建了一个新指标并使用模型对其进行编译:
def new_classes_acc(y_true, y_pred):
actual = tf.floor( y_true / 2 )
predicted = tf.floor( y_pred / 2 )
return K.categorical_crossentropy(actual, predicted)
像这样编译它:
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy', new_classes_acc])
我得到nan
准确度值。这样做的正确方法是什么?
解决方案
由于有 4 个类,并且您已将其设置categorical_crossentropy
为损失,因此标签是 one-hot 编码的并且将是 shape (n_samples, 4)
。因此,首先您需要使用 function 找到真实和预测的类argmax
,然后使用floor
function(此外,您希望创建一个度量而不是损失函数;因此您不应该使用K.categorical_crossentropy
):
from keras import backend as K
import tensorflow as tf
def custom_metric(y_true, y_pred):
tr = tf.floor(K.argmax(y_true, axis=-1) / 2)
pr = tf.floor(K.argmax(y_pred, axis=-1) / 2)
return K.cast(K.equal(tr, pr), K.floatx())
现在,让我们测试一下。首先我们创建一个简单的模型并编译它:
model = Sequential()
model.add(Dense(4, activation='softmax', input_shape=(2,)))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy', custom_metric])
然后我们创建虚拟数据:
import numpy as np
data = np.array([1, 2]).reshape(1, 2)
并使用我们的模型来预测给定数据的标签:
print(model.predict(data))
# prints: [0.04662106, 0.8046941 , 0.07660434, 0.0720804 ]
所以第二类的概率最高,将是预测的标签。现在,考虑到我们定义的自定义指标,给定一个[1, 0, 0, 0]
或[0, 1, 0, 0]
作为真实标签,自定义指标应该给我们 1(即 100%)。让我们确认一下:
true_labels = np.array([1, 0, 0, 0]).reshape(1,4)
print(model.evaluate(data, true_labels)) # gives: [3.0657029151916504, 0.0, 1.0]
返回列表的第一个元素对应于损失,第二个元素对应于accuracy
我们的自定义指标,第三个元素对应于我们的自定义指标。如您所见,准确度为零(因为真实类是第一类,但预测类是第二类),自定义指标为 1,正如预期的那样。
另一种情况:
true_labels = np.array([0, 1, 0, 0]).reshape(1,4)
print(model.evaluate(data, true_labels)) # gives: [0.21729297935962677, 1.0, 1.0]
这里的准确率是一(因为真实类和预测类都是二类),自定义指标也是一。对于剩下的两种情况,您可以进一步确认这一点[0, 0, 1, 0]
和[0, 0, 0, 1]
作为真正的标签;对于自定义指标的值,两者都应返回零。
奖励:如果标签是稀疏的,即 0、1、2 和 3,怎么办?然后,您可以使用keras.np_utils.to_categorical()
方法对它们进行一次热编码,然后使用上面定义的自定义指标。
推荐阅读
- angular - NativeScript 刷新后拉到 ActionBar 后面的刷新页面 | nativescript-pulltorefresh
- r - R闪亮:可以通过条件面板选择输入
- c# - 使用一个登录表单进行管理员和用户登录
- code-push - code-push:如何查找从 appcenter 网站添加的应用程序的部署密钥(添加新应用程序选项)
- vb.net - VB.net Winform 自动完成匹配字符串的任何部分
- solr - 如何使用隐式路由设置 solr 集群(v4.9.1)
- push-notification - 如何在 FCM UI 中定位特定受众?
- php - MYSQL - 两个表从每个用户的第二个表最后跟踪的行中获取记录
- javascript - React Native - 使用 NativeBase Drawer 时无法读取未定义的属性“_root”
- css - 3列布局使用float,第三列的元素在中间列之上