python - CNN keras手写识别准确率高但预测差
问题描述
我基本上是为一个学校项目做这个,并按照一些指南使用 CNN 制作神经元网络。我使用的库是 cv2、NumPy、TensorFlow 和 matplotlib。我目前面临的问题是我的网络具有很高的准确性但预测非常糟糕。我确保图片是倒置的并且是 28x28。我还将要预测的图像数量从 5 个扩展到 10 个。我还尝试添加更多层,但也没有帮助。如果有人可以帮助我,那就太棒了!我对此也很陌生,所以请尽你所能解释!
输出示例:如您所见,笔迹还不错,但仍然无法预测它是 6 而是 1。
这是代码:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(a_train, b_train), (a_test, b_test) = mnist.load_data()
a_train = tf.keras.utils.normalize(a_train, axis=1)
a_test = tf.keras.utils.normalize(a_test, axis=1)
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(units=255, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(units=255, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(units=20, activation=tf.nn.softmax))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(a_train, b_train, epochs=50)
lost, accuracy = model.evaluate(a_train, b_train)
print(lost)
print(accuracy)
model.save('test.model')
for x in range(1,11):
img = cv2.imread(fr'C:\Users\Eric\PycharmProjects\pythonProject2\test.model\{x}.png')[:,:,0]
img = np.invert(np.array([img]))
prediction = model.predict(img)
print(f'My Guess is: {np.argmax(prediction)}')
plt.imshow(img[0], cmap=plt.cm.binary)
plt.show()
我尝试做的事情: 我尝试添加更多层,假设它会训练并有更好的预测。我添加了更多的样本数量,看看我是否可以有更高的预测。我从 5 上升到 10,但仍然是 20% 的正确预测。我尝试过更改 Epoch 并尝试了更多的批量大小,但也没有用。
我几乎被困在这一点上,尽我所能去理解它,但根本无法改进它。如果有人有任何提示,请告诉我!
解决方案
预测时需要对图像进行标准化。cv2.imread
创建一个从 0 到 255 的数组。您可以通过除以对其进行归一img
化255.
您用来预测的图像也应该在黑色背景上有白色文本。
最后,您不需要np.invert
.
所以你的代码应该是
for x in range(1, 11):
img = np.expand_dims(cv2.imread(f'C:\Users\Eric\PycharmProjects\pythonProject2\test.model\{x}.png')[:, :, 0], 0) / 255.
prediction = model.predict(img)
print(f'My Guess is: {np.argmax(prediction)}')
plt.imshow(img[0], cmap=plt.cm.binary)
plt.show()
推荐阅读
- opencv - 如何提高 websocket 接收数据的速度
- sql - 在 Hive SQL 中提取具有特定模式的子字符串
- c++ - 我是否错误地实现了 strcpy_s?
- react-native - React-Native & async 函数:设置条件值
- python - 为什么我只有一个测试用例错误,而其他所有测试用例都正确?
- c++ - 为什么在 C++20 中 unique_ptr 不是 equal_comparable_with nullptr_t?
- c# - 如何转换列表
剑道网格的 JSON - wpf - WPF SelectedItem 绑定到源不影响键盘导航
- c# - Winform需要顶部和底部PictureBox的碰撞
- angular - 错误 TS2339:“可观察”类型上不存在属性“过滤器”
'。在角 rxjs