python - 在连体 CNN 上使用 .fit_generator 时出错
问题描述
我们正在尝试拟合 Siamese CNN,并且在我们想要使用 .fit_generator 将数据提供给模型的最后一部分遇到了麻烦。
我们的生成器函数如下所示:
def get_batch(h, w, batch_size = 100):
anchor =np.zeros((batch_size,h,w,3))
positive =np.zeros((batch_size,h,w,3))
negative =np.zeros((batch_size,h,w,3))
while True:
#Choose index at random
index = np.random.choice(n_row, batch_size)
for i in range(batch_size):
list_ind = train_triplets.iloc[index[i],]
#print(list_ind)
anchor[i] = train_data[list_ind[0]]
positive[i] = train_data[list_ind[1]]
negative[i] = train_data[list_ind[2]]
anchor = anchor.astype("float32")
positive = positive.astype("float32")
negative = negative.astype("float32")
yield [anchor,positive,negative]
该模型期望获得一个包含 3 个数组的列表作为 Siamese CNN 的输入。但是,我们收到以下错误消息:
Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), but instead got the following list of 1 arrays
如果我们只是手动提供一个包含 3 个数组的列表,那么它就可以工作。这就是为什么我们怀疑错误是由 .fit_generator 函数引起的。我们必须使用 .fit_generator 函数,因为由于内存问题我们无法存储数据。
有人知道这是为什么吗?
提前谢谢。
解决方案
根据错误,模型需要 3 个数组,而不是 3 个数组的列表。所以改成
yield [anchor,positive,negative]
可能yield anchor,positive,negative
会奏效。