python - 检查输入时出错:预期 conv2d_3_input 的形状为 (64, 64, 3) 但得到的数组的形状为 (64, 64, 1)
问题描述
基本上,我在学习 udemy 的机器学习 AZ 教程,在那里我学会了训练我自己的卷积神经网络模型,但是没有向我展示如何使用该训练过的模型——如何提供输入并获得输出。我在 youtube 上查看了一些教程,以通过 Keras 中的 cnn 传递单个图像,并阅读了本教程。当我运行代码时出现错误
ValueError:检查输入时出错:预期 conv2d_3_input 的形状为 (64, 64, 3) 但得到的数组的形状为 (64, 64, 1)
现在我完全不知道如何处理这个问题。如果有人能告诉我在我训练了我的网络之后该怎么做以及如何解决整个问题,我将非常感激我尝试改变 img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
阅读颜色,但我认为这与此无关
我用来训练卷积网络的代码:
import keras
from keras.models import Sequential
from keras.layers import Convolution2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
初始化 cnn
classifier = Sequential()
step1-应用卷积
classifier.add(Convolution2D(32,3,3, border_mode = 'same', input_shape=(64,64,3), activation = 'relu'))
应用最大池化
classifier.add(MaxPooling2D(pool_size=(2,2)))
添加另一个卷积层以获得更好的准确性
classifier.add(Convolution2D(32,3,3, border_mode = 'same', activation = 'relu'))
classifier.add(MaxPooling2D(pool_size=(2,2)))
应用展平
classifier.add(Flatten())
第 4 步 - 完全连接
classifier.add(Dense(output_dim = 128, activation = 'relu'))
#input
classifier.add(Dense(output_dim = 1, activation = 'sigmoid'))
#output
全部编译
classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
#fitting the cnn to the images
from keras.preprocessing.image import ImageDataGenerator
#pixels take values between 0 & 255
train_datagen = ImageDataGenerator(rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
training_set = train_datagen.flow_from_directory('dataset/training_set', target_size=(64, 64), batch_size=32, class_mode='binary')
test_set = test_datagen.flow_from_directory('dataset/test_set',
target_size=(64, 64),
batch_size=32,
class_mode='binary')
classifier.fit_generator(training_set,
samples_per_epoch = 8000,
nb_epoch = 25,
verbose = 1,
validation_data = test_set,
nb_val_samples = 2000)
classifier.save('my_model.h5')
我在新文件中使用的代码来试用我训练有素的网络
import cv2
import tensorflow as tf
import keras
from keras.models import Sequential
Categories = ["Dogs","cats"]
def prepare(filepath):
img_size = 64
img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
new_array = cv2.resize(img_array, (img_size,img_size))
return new_array.reshape(-1,img_size,img_size, 1)
classifier = keras.models.load_model('my_model.h5')
prediction = classifier.predict([prepare('/Users/m.zain/Documents/machine learning-2019 dps/Machine Learning A-Z Template Folder/Part 8 - Deep Learning/Section 40 - Convolutional Neural Networks (CNN)/dog.jpg')])
print(prediction)
prediction = classifier.predict([prepare('/Users/m.zain/Documents/machine learning-2019 dps/Machine Learning A-Z Template Folder/Part 8 - Deep Learning/Section 40 - Convolutional Neural Networks (CNN)/dog.jpg')])
print(Categories[int(prediction[0][0])])
prediction = classifier.predict([prepare('/Users/m.zain/Documents/machine learning-2019 dps/Machine Learning A-Z Template Folder/Part 8 - Deep Learning/Section 40 - Convolutional Neural Networks (CNN)/cat.jpg')])
print(Categories[int(prediction[0][0])])
当我执行代码时,我希望得到“狗”作为输出:
prediction = classifier.predict([prepare('/Users/m.zain/Documents/machine learning-2019 dps/Machine Learning A-Z Template Folder/Part 8 - Deep Learning/Section 40 - Convolutional Neural Networks (CNN)/dog.jpg')])
print(Categories[int(prediction[0][0])])
当我运行此命令时,将 cat 作为输出:
prediction = classifier.predict([prepare('/Users/m.zain/Documents/machine learning-2019 dps/Machine Learning A-Z Template Folder/Part 8 - Deep Learning/Section 40 - Convolutional Neural Networks (CNN)/cat.jpg')])
print(Categories[int(prediction[0][0])])
解决方案
推荐阅读
- apache-nifi - Nifi从csv列值获取文件名
- fabricjs - 在 viewportTransform 可视化之外隐藏对象
- angular - 将 ngModel 放在输入标签中会停止项目在浏览器中显示
- oracle - 如何使用 oracle 在 xpath 查询中声明数组
- css - 如何在使用 ReactComponent 作为徽标导入的 svg 中选择特定部分?
- rest - 无法在 Jmeter 中获取访问令牌
- r - R- 使用 DBI dbExecute 或其他包向 mariadb/mySQL 插入左连接
- c++ - 模板默认参数在 C++ 中不起作用
- flutter - 如何从子小部件调用 setState*
- visual-c++ - 致命错误 LNK1104 无法在通用 Windows C++ 项目中使用 /Qspectre 和 Spectre 库 *installed* 打开文件 MSVCRT.lib