首页 > 解决方案 > 在 keras 训练期间使用 fit_generator 的问题

问题描述

我正在处理非常大的文本数据集。我考虑过使用model.fit_generatormethod 而不是 simple model.fit,所以我尝试使用这个生成器:

def TrainGenerator(inp, out):
  for i,o in zip(inp, out):
    yield i,o

当我在训练期间尝试使用它时,使用:

#inp_train, out_train are lists of sequences padded to 50 tokens
model.fit_generator(generator = TrainGenerator(inp_train,out_train),
                    steps_per_epoch = BATCH_SIZE * 100,
                    epochs = 20,
                    use_multiprocessing = True)

我得到:

ValueError: Error when checking input: expected embedding_input to have shape (50,) but got array with shape (1,)

现在,我尝试使用简单的model.fit方法,效果很好。所以,我认为我的问题出在生成器上,但是由于我是使用生成器的新手,我不知道如何解决它。完整的模型摘要是:

Layer (type)                 Output Shape            
===========================================
Embedding (Embedding)      (None, 50, 400)           
___________________________________________
Bi_LSTM_1 (Bidirectional)  (None, 50, 1024)          
___________________________________________
Bi_LSTM_2 (Bidirectional)  (None, 50, 1024)          
___________________________________________
Output (Dense)             (None, 50, 153)           
===========================================

#编辑 1

第一条评论以某种方式触发了我。我意识到我误解了发电机的工作原理。我的生成器的输出是一个形状为 50 的列表,而不是一个形状为 50 的 N 个列表的列表。所以我深入研究了 keras 文档并找到了这个。所以,我改变了我的工作方式,这是作为生成器工作的类:

class BatchGenerator(tf.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        
        return batch_x, to_categorical(batch_y,num_labels)

to_categorical函数在哪里:

def to_categorical(sequences, categories):
    cat_sequences = []
    for s in sequences:
        cats = []
        for item in s:
            cats.append(np.zeros(categories))
            cats[-1][item] = 1.0
        cat_sequences.append(cats)
    return np.array(cat_sequences)

所以,我现在注意到的是我的网络的良好性能提升,每个时期现在持续一半。发生这种情况是因为我有更多的可用 RAM,因为现在我没有将所有数据集加载到内存中吗?

标签: pythonkerasneural-network

解决方案


推荐阅读