首页 > 解决方案 > 在连体 CNN 上使用 .fit_generator 时出错

问题描述

我们正在尝试拟合 Siamese CNN,并且在我们想要使用 .fit_generator 将数据提供给模型的最后一部分遇到了麻烦。

我们的生成器函数如下所示:

def get_batch(h, w, batch_size = 100):

    anchor =np.zeros((batch_size,h,w,3))
    positive =np.zeros((batch_size,h,w,3))
    negative =np.zeros((batch_size,h,w,3))

    while True:
    #Choose index at random
        index = np.random.choice(n_row, batch_size)
        for i in range(batch_size):
            list_ind = train_triplets.iloc[index[i],]
            #print(list_ind)
            anchor[i] =  train_data[list_ind[0]]
            positive[i] = train_data[list_ind[1]]
            negative[i] = train_data[list_ind[2]]

            anchor = anchor.astype("float32")
            positive = positive.astype("float32")
            negative = negative.astype("float32")

        yield [anchor,positive,negative]



该模型期望获得一个包含 3 个数组的列表作为 Siamese CNN 的输入。但是,我们收到以下错误消息:

Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), but instead got the following list of 1 arrays

如果我们只是手动提供一个包含 3 个数组的列表,那么它就可以工作。这就是为什么我们怀疑错误是由 .fit_generator 函数引起的。我们必须使用 .fit_generator 函数,因为由于内存问题我们无法存储数据。

有人知道这是为什么吗?

提前谢谢。

标签: pythonkerascomputer-visiongeneratorconv-neural-network

解决方案


根据错误,模型需要 3 个数组,而不是 3 个数组的列表。所以改成 yield [anchor,positive,negative]可能yield anchor,positive,negative会奏效。


推荐阅读