python - 手动计算分类准确度与 keras 给出的不匹配
问题描述
我在解释 Kerasmodel.fit()
方法的输出时遇到了麻烦。
那个设定
print(tf.version.VERSION) # 2.3.0
print(keras.__version__) # 2.4.0
我有一个用于 3 类分类问题的简单前馈模型:
def get_baseline_mlp(signal_length):
input_tensor = keras.layers.Input(signal_length, name="input")
dense_1 = keras.layers.Flatten()(input_tensor)
dense_1 = keras.layers.Dense(name='dense_1',activation='relu',units=500)(dense_1)
dense_1 = keras.layers.Dense(name='dense_2',activation='relu',units=500)(dense_1)
dense_1 = keras.layers.Dense(name='dense_3',activation='relu',units=500)(dense_1)
dense_1 = keras.layers.Dense(name='dense_4',activation='softmax',units=3, bias_initializer='zero')(dense_1)
model = keras.models.Model(inputs=input_tensor, outputs=[dense_1])
model.summary()
return model
我的训练数据是单变量时间序列,我的输出是长度为 3 的单热编码向量(我的分类问题中有 3 个类)
模型编译如下:
mlp_base.compile(optimizer=optimizer,
loss='categorical_crossentropy',
metrics=['categorical_accuracy'])
我有一个功能可以通过两种方法手动计算我的预测的准确性:
def get_accuracy(model, true_x, true_y):
res = model.predict(true_x)
res = np.rint(res)
right = 0
for i in range(len(true_y[:, 0])):
if np.array_equal(res[i, :], true_y[i, :]):
#print(res[i,:], tr_y[i,:])
right += 1
else:
pass
tot = len(true_y[:,0])
print('True - total', right, tot)
print('acc: {}'.format((right/tot)))
print()
print(' method 2 - categorical')
res = model.predict(true_x)
res = np.argmax(res, axis=-1)
true_y = np.argmax(true_y, axis=-1)
right = 0
for i in range(len(true_y)):
if res[i] == true_y[i]:
right += 1
else:
pass
tot = len(true_y)
print('True - total', right, tot)
print('acc: {}'.format((right/tot)))
问题
在训练结束时,输出的分类准确度与我使用自定义函数得到的不匹配。
训练输出:
Model: "functional_17"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input (InputLayer) [(None, 9000)] 0
_________________________________________________________________
flatten_8 (Flatten) (None, 9000) 0
_________________________________________________________________
dense_1 (Dense) (None, 500) 4500500
_________________________________________________________________
dense_2 (Dense) (None, 500) 250500
_________________________________________________________________
dense_3 (Dense) (None, 500) 250500
_________________________________________________________________
dense_4 (Dense) (None, 3) 1503
=================================================================
Total params: 5,003,003
Trainable params: 5,003,003
Non-trainable params: 0
-------------------------------------------------------------------
Fit model on training data
Epoch 1/2
20/20 [==] - 0s 14ms/step - loss: 1.3796 categorical_accuracy: 0.3250 - val_loss: 0.9240 -
Epoch 2/2
20/20 [==] - 0s 8ms/step - loss: 0.8131 categorical_accuracy: 0.6100 - val_loss: 1.2811
精度函数输出:
True / total 169 200
acc: 0.845
method 2
True / total 182 200
acc: 0.91
为什么我得到错误的结果?我的准确性实施错误吗?
更新
按照desertnaut 的建议更正设置仍然无效。
拟合输出:
Epoch 1/3
105/105 [===] - 1s 9ms/step - loss: 1.7666 - categorical_accuracy: 0.2980
Epoch 2/3
105/105 [===] - 1s 6ms/step - loss: 1.2380 - categorical_accuracy: 0.4432
Epoch 3/3
105/105 [===] - 1s 5ms/step - loss: 1.0318 - categorical_accuracy: 0.5989
如果我使用 keras 的分类准确度函数,我仍然会得到不同的结果。
cat_acc = keras.metrics.CategoricalAccuracy()
cat_acc.update_state(tr_y2, y_pred)
print(cat_acc.result().numpy()) # outputs : 0.7211079
有趣的是,如果我用上述方法计算验证准确度,我会得到一致的输出。
解决方案
不太确定您的准确度计算(似乎不必要的复杂,我们总是更喜欢向量计算而不是for
循环),但是您的代码有两个问题可能会影响结果(甚至使它们变得毫无意义)。
第一个问题是,由于您处于多类设置中,因此您应该使用loss='categorical_crossentropy'
, 而不是 'binary_crossentropy'
;来编译模型。在为什么 binary_crossentropy 和 categorical_crossentropy 对同一问题给出不同的性能?看看当你以这种方式混合损失和准确性时会发生什么(另外,'binary_accuracy'
这里绝对没有意义)。
第二个问题是您错误地使用activation='sigmoid'
了最后一层:由于您处于多类(而不是多标签)设置中,标签单热编码,因此最后一层中的激活应该是softmax
,而不是sigmoid
。
推荐阅读
- javascript - 防止使用 svg 进行亚像素渲染
- c++ - 如何在我的向量实现中编写 operator=(带有移动语义)和 shrink_to_fit 函数?
- json - 如何从量角器中的 JSON 对象中提取字符串
- typo3 - 单击项目隐藏图标不刷新屏幕
- reporting-services - 根据 SSRS 中的记录数调整表大小
- jhipster - jhipster 中 java 服务部署的上下文路径
- java - 在Java Spring 4中从应用程序A到应用程序B进行rest或soap api调用时将用户名存储在数据库中
- testing - 这个异常是什么意思?
- pandas - 如何使用共享字典在一个循环中重命名多个数据帧?
- augmented-reality - 如何将指示器或指针锚定到 Sceneform?