python - 4维数据经过卷积神经网络后变成2维
问题描述
我正在训练一个具有两种输入类型的神经元网络:图像和 BR(红色上的蓝色,它是一种非图像特征,如身高、体重......)。为此,我在 keras 中使用 fit 函数,并将图像转换为列表以供输入。但我不知道为什么图像列表,它有 4 个维度的形状在适应时变成了 2 个维度,我得到了如下错误:
检查输入时出错:预期的 dense_1_input 具有 3 个维度,但得到的数组具有形状 (1630, 1)
当我将图像列表转换为数组时,我检查了 image_array 的形状,它正好有 4 个维度(特别是它的形状是 1630、60、60、3)。即使在 fit 函数之前,它仍然具有相同的形状。所以我真的不知道为什么形状变成了(1630,1)。谁能为我解释一下?
这是我的代码:
from keras.utils.np_utils import to_categorical
import pandas as pd
import numpy as np
import os
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, Model
from keras.layers import Input, Activation, Dropout, Flatten, Dense,Concatenate, concatenate,Reshape, BatchNormalization, Merge
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.optimizers import Adagrad
from sklearn import preprocessing
from scipy.misc import imread
import time
from PIL import Image
import cv2
img_width, img_height = 60, 60
img_list = []
BR_list = []
label_list = []
data_num = 1630
folder1 = "cut2/train/sugi/"
folder2 = "cut2/train/hinoki/"
def imgConvert(file_path):
img = imread(file_path,flatten = True)
img = np.arange(1*3*60*60).reshape((60,60,3))
img = np.array(img).reshape(60,60,3)
img = img.astype("float32")
return img
def B_and_R(img_path):
img = cv2.imread(img_path)
B = 0
R = 0
for i in range(25,35):
#print(i)
for j in range(25,35):
B = B+img[i,j,0]
R = R+img[i,j,2]
#(j)
#(img[i,j])
ave_B = B/100
ave_R = R/100
BR = ave_B/ave_R
return BR
def getData(path,pollen):
for the_file in os.listdir(path):
#print(the_file)
file_path = os.path.join(path, the_file)
B_over_R = B_and_R(file_path)
img_arr = imgConvert(file_path)
#writer.writerow([img_arr,B_over_R,"sugi"])
img_list.append(img_arr)
BR_list.append(B_over_R)
lb = np.zeros(2)
if pollen == "sugi":
lb[0] +=1
else:
lb[1] +=1
label_list.append(lb)
if __name__ == '__main__':
getData(folder1,"sugi")
getData(folder2,"hinoki")
img_arr = np.array(img_list)
print(img_arr.shape)
#.reshape(img_list[0],1,img_width,img_height)
img_arr.astype("float32")
img_arr /= 255
print(img_arr.shape)
img_array = np.expand_dims(img_arr, axis = 0)
img_array = img_array[0,:,:,:,:]
print(img_array.shape)
"""
datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True)
datagen.fit(img_array)
"""
#img_array = img_array.reshape(img_array[0],1,img_width,img_height)
print(img_array.shape)
label_arr = np.array(label_list)
print(label_arr.shape)
#label_array = np.expand_dims(label_arr, axis = 0)
#label_array = label_array[0,:,:,:,:]
BR_arr = np.array(BR_list)
print(BR_arr.shape)
#BR_array = np.expand_dims(BR_arr, axis = 0)
#BR_array = BR_array[0,:,:,:,:]
#print(len([img_arr,BR_arr]))
input_tensor = Input(shape=(img_width, img_height,3))
vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)
# FC層の作成
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
#print(top_model.summary())
# VGG16とFC層を結合してモデルを作成
branch1 = Model(input=vgg16.input, output=top_model(vgg16.output))
#model.summary()
print(branch1.summary())
branch2 = Sequential()
branch2.add(Dense(1, input_shape=(data_num,1), activation='sigmoid'))
#branch1.add(Reshape(BR.shape, input_shape = BR.shape))
branch2.add(BatchNormalization())
branch2.add(Flatten())
print(branch2.summary())
merged = Merge([branch1, branch2], mode = "concat")
model = Sequential()
model.add(merged)
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(2, activation='softmax'))
#last_model = Model(input = [branch1.input,branch2.input],output=model())
print(model.summary())
model.compile(loss='categorical_crossentropy',
optimizer=optimizers.SGD(lr=1e-3, momentum=0.9),
metrics=['accuracy'])
print(img_array.shape)
model.fit([img_array,BR_arr], label_arr,
epochs=5, batch_size=100, verbose=1)
解决方案
好的,那么问题是输入形状。
虽然分支 2 的数据是 2D (batch, 1)
,但您的模型也应该有一个 2D 输入:input_shape = (1,)
. (批量大小在 中被忽略input_shape
)
推荐阅读
- spring-boot - Sprint Data JPA - save() 方法在表中插入重复项
- python - 我应该如何更改“urlpatterns”中的路径?
- node.js - 我保存了从响应中获得的音频文件,但它不播放
- html - 在移动屏幕上,我的 div 显示了不需要的白色边框
- javascript - Daterangepicker 如何获取预定义日期范围的值
- arrays - 在ant design table中使用多维数组
- c# - UWP SQLite 并发说明
- async-await - 条带检索客户方法是返回客户还是将其传递给回调?
- javascript - 运行 python 应用程序的最简单方法是从 Web 应用程序接收数据,处理/操作它,然后将其推回前端
- python - Pycharm 失去了视图和模板之间的连接