python - 自定义数据生成器 keras 模型预期 2 个数组,但收到 1 个
问题描述
所以我试图让这个自定义生成器正常工作,但似乎有问题。当我尝试使用 gen 时,似乎生成器工作正常。next () 它正在生产我想要的东西。但是,它可能不会以我认为应该的形状制成。
# Image processing
def preprocess_image(image_path):
img = image.load_img(image_path, target_size=(224, 224))
img = image.img_to_array(img)
img = preprocess_input(img)
return img
def image_generator(data, batch_size):
datagen_args = dict(horizontal_flip=True)
datagen = ImageDataGenerator(**datagen_args)
while True:
for i in range(0, len(data) // batch_size):
# get the label and the imagepath
imgpath, label = data[i]
# Process the image
img = preprocess_image(imgpath)
img = datagen.random_transform(img)
#img = np.expand_dims(img, axis=0)
# add a 0 for a dummy variable
dummy_label = np.zeros(len(label))
x_data = np.array([img, label])
yield x_data, dummy_label
# Prepare data need a array [image, label]
X = [] # hold the data before processing
Y = []
IMAGE_DIR = 'dataset/gt_bbox'
for file in os.listdir(IMAGE_DIR):
file_path = os.path.join(IMAGE_DIR, file)
label = int(file.split('_')[0])
X.append(file_path)
Y.append(label)
# Convert to catigorical
Y = to_categorical(Y)
image_dataset = []
for i in range(0,len(X)):
image_dataset.append([X[i], Y[i]])
# Split to train test data
train, val = train_test_split(image_dataset)
BATCHSIZE = 32
imggen = image_generator(train, BATCHSIZE)
valgen = image_generator(val, BATCHSIZE)
model.fit_generator(imggen,
steps_per_epoch=1000,
epochs=10,
validation_data=valgen,
validation_steps=300,
verbose=1)
我的模型是这样设置的
input_images = Input(shape=(224, 224, 3), name='input_image') # input layer for images
input_labels = Input(shape=(1,), name='input_label') # input layer for labels
embeddings = base_network([input_images]) # output of network -> embeddings
labels_plus_embeddings = Concatenate(axis=-1)([input_labels, embeddings]) # concatenating the labels + embeddings
model = Model(inputs=[input_images, input_labels], outputs=labels_plus_embeddings)
我在构建模型的方式上可能是错误的,但对我来说似乎是正确的。
错误信息
ValueError:检查模型输入时出错:您传递给模型的 Numpy 数组列表不是模型预期的大小。预计会看到 2 个数组,但得到了以下 1 个数组的列表: [array([[array([[[-0.56078434, -0.52156866, -0.4980392 ], [-0.56078434, -0.52156866, -0.4980392 ], [-0.56078434, -0.52156866, -0.4980392 ], ..., [-0.5764706 , -0.545098 ...
解决方案
推荐阅读
- android - 渲染问题,我该如何解决?
- sql - 如何屏蔽 Redshift 中的列?
- php - 如何修复大型数据库从连接表中查找第二个最后修改结果的“超时”错误?
- javascript - 找不到模块“内部/fs”无法按照现有解决方案工作
- c# - 组合两个表并将它们的查询写入列表
- java - Locale.getDefault() 究竟检索了什么?
- c++ - 从 SQL Server 的 Azure 帐户获取 IP
- dns - 将域添加到虚拟机
- python - 在我的瀑布图中,“虚拟第三轴”的正确 matplotlib 变换是什么?
- markdown - jekyll 中是否有“kramdownify”过滤器?