python - 为什么我的 python 代码出现值错误
问题描述
使用过:tensorflow、keras、python、flask
Error: ValueError: Input 0 of layer conv2d_3 is incompatible with the layer: : expected min_ndim=4, found ndim=3. Full shape received: (32, 224, 1)
神经网络的结构:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
rescaling_1 (Rescaling) (None, 224, 224, 3) 0
_________________________________________________________________
conv2d (Conv2D) (None, 224, 224, 16) 448
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 112, 112, 16) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 112, 112, 32) 4640
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 56, 56, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 56, 56, 64) 18496
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 28, 28, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 50176) 0
_________________________________________________________________
dense (Dense) (None, 128) 6422656
_________________________________________________________________
dense_1 (Dense) (None, 2) 258
=================================================================
Total params: 6,446,498
Trainable params: 6,446,498
Non-trainable params: 0
_________________________________________________________________
from flask import *
import h5py
from PIL import Image, ImageFile
from io import BytesIO
from werkzeug.utils import secure_filename
import numpy as np
from matplotlib import image
import os
import uuid
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from werkzeug.wrappers import Request, Response
import tensorflow as tf
import numpy as np
IMGSIZE=224
IMGCHANNELS=3*224*224
ALLOWEDEXTENSIONS=['png','jpg','txt','pdf','jfif','jpeg','gif']
classes=('NORMAL','PNEMONIAL')
app=Flask(__name__)
model=load_model(os.path.join(os.getcwd(),"chestXray.h5"))
def preprocess(inputs):
inputs=np.array(inputs)
l=inputs.size
for i in range(l):
j=inputs[i]
j /=255
inputs[i]=j
j=0
return inputs
#Allowed file function
def allowededFile(filename):
return '.' in filename and filename.rsplit('.',1)[1] in ALLOWEDEXTENSIONS
#Flask ROutes
@app.route("/")
def index():
return render_template("chest.html", Prediction=" ")
@app.route("/api/image", methods=['GET','POST'])
def api():
file=request.files['file']
if file.filename=='':
return render_template("chest.html", Prediction="You did not upload a file")
if file and allowededFile(file.filename):
filename=secure_filename(file.filename)
print("***"+filename)
x=[]
ImageFile.LOAD_TRUNCATED_IMAGES=False
img=Image.open(BytesIO(file.read()))
img.load()
img=img.resize((IMGSIZE,IMGSIZE),Image.ANTIALIAS)
x=image.img_to_array(img)
#np.expand_dims(x,axis=-1)
x.reshape((1,224,224,1))
pred=model.predict(x)
lst=decode_predictions(pred,top=3)
items=[]
for item in lst[0]:
items.append({'name':item[1],'probability':float(item[2])})
response={'prediction':items}
return render_template("chest.html",Prediction="The state of this person's lungs is most likely {}".format(response))
else:
return render_template("chest.html",Prediction='Invalid file extension')
if __name__=="__main__":
from werkzeug.serving import run_simple
run_simple('localhost',9000,app)
型号代码:
model=Sequential([
layers.experimental.preprocessing.Rescaling(1./255,input_shape=(imgSize,imgSize,3)),
layers.Conv2D(16,3,padding='same',activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32,3,padding='same',activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64,3,padding='same',activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(numClasses)
])
解决方案
我检查了你的代码,没有发现问题,如果x.shape
是(1,224,224,3)
.
>>> from tensorflow.keras import layers, Sequential
>>> model=Sequential([
layers.experimental.preprocessing.Rescaling(1./255,input_shape=(imgSize,imgSize,3)),
layers.Conv2D(16,3,padding='same',activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32,3,padding='same',activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64,3,padding='same',activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(numClasses)
])
>>> import numpy as np
>>> x = np.zeros((1,224,224,3))
>>> model.predict(x)
2021-05-02 18:35:15.325046: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
array([[0., 0.]], dtype=float32)
>>>
请检查您的 x.shape,如果它与 input_shape 不兼容,您将收到此错误!
以下是模型摘要:
>>> model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
rescaling_2 (Rescaling) (None, 224, 224, 3) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 224, 224, 16) 448
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 112, 112, 16) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 112, 112, 32) 4640
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 56, 56, 32) 0
_________________________________________________________________
conv2d_8 (Conv2D) (None, 56, 56, 64) 18496
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 28, 28, 64) 0
_________________________________________________________________
flatten_2 (Flatten) (None, 50176) 0
_________________________________________________________________
dense_3 (Dense) (None, 128) 6422656
_________________________________________________________________
dense_4 (Dense) (None, 2) 258
=================================================================
Total params: 6,446,498
Trainable params: 6,446,498
Non-trainable params: 0
_________________________________________________________________
>>>
推荐阅读
- java - Media.Setlooping 不连续运行
- javascript - Javascript 模块在 text/javascript html 脚本标签中的使用
- c - 在 C 中使用递归和回溯解决数独领域
- css - 图像按高度缩放
- javascript - WebRTC onnegotiationneeded
- javascript - JQuery 从一个值中设置数据名称
- javascript - 未选择的元素也被突出显示,为什么?
- c# - 在哪里/如何存储或生成客户端机密?
- django - 在 Django 应用程序中,在哪里以及如何设置 LANGUAGE_CODE?
- c++ - 三元运算符内的赋值