首页 > 解决方案 > Keras fit 和 fit_generator 返回完全不同的结果

问题描述

Keras fit 和 fit_generator 返回完全不同的结果, fit_generator 降低了近 20% 的精度。我确实在数据生成器中使用了 shuffle。我在下面附上了我的 data_generator。谢谢 !

def data_generator(input_x, input_y, batch_size = BATCH_SIZE):
    loopcount = len(input_x) // batch_size
    while True:
        i = random.randint(0, loopcount - 1)
        x_batch = input_x[i*batch_size:(i+1)*batch_size]
        y_batch = input_y[i*batch_size:(i+1)*batch_size]
        yield x_batch, y_batch

我的 model.fit_generator 如下所示:

    model.fit_generator(generator = data_generator(x_train, y_train, batch_size = BATCH_SIZE),steps_per_epoch = len(x_train) // BATCH_SIZE, epochs = 20, validation_data = data_generator(x_val, y_val, batch_size = BATCH_SIZE), validation_steps = len(x_val) // BATCH_SIZE)

标签: debuggingkeras

解决方案


TLDR;使用fit而不是fit_generator

从 TensorFlow 2.x 开始,该fit方法可以将生成器作为输入,因此fit_generator我强烈建议不要使用fit.


推荐阅读