TypeError: fit_generator() got multiple values for argument 'steps_per_epoch'

Problem description

I am trying to train a Keras model. Below is the function that trains the model.

def train_model(input_videos, video_label, mapping, micro_expressions, val_x=None, val_micro=None, val_y=None, return_best=False):
    print("train_model")
    input_videos = np.asarray(input_videos)
    data_x, data_y, data_mapping, data_micro = H1_preprocessing(input_videos, video_label, mapping, micro_expressions)
    model, callbacks_lst = build_model()
    print("data_x: ", data_x.shape)
    print("data_y: ", data_y.shape)
    if val_x is not None and val_y is not None and val_micro is not None:
        hist = model.fit_generator([data_x, data_micro[0], data_micro[1], data_micro[2]], data_y, steps_per_epoch = 2, epochs = 20, verbose = 2, callbacks = callbacks_lst, validation_data=([val_x, val_micro[0], val_micro[1], val_micro[2]], val_y), use_multiprocessing=True, shuffle=True)
        print(hist) 
    else:
        hist = model.fit_generator([data_x, data_micro[0], data_micro[1], data_micro[2]], data_y, steps_per_epoch = 2, epochs = 20, verbose = 2, callbacks = callbacks_lst, use_multiprocessing=True, shuffle=True)
        print(hist)

    if return_best:
        print("Applying weights")
        model.load_weights("weights.best.hdf5")
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

The fit_generator() call raises an error. I don't know which arguments to pass in which positions. It shows the following error.

Traceback (most recent call last):
  File "concatenated_classifier.py", line 697, in <module>
    trained_model = train_model(list_gray_train_videos, list_train_label, clips_mapping, micro_expressions, None, None, None, True)
  File "concatenated_classifier.py", line 613, in train_model
    hist = model.fit_generator([data_x, data_micro[0], data_micro[1], data_micro[2]], data_y, steps_per_epoch = 2, epochs = 20, verbose = 2, callbacks = callbacks_lst, use_multiprocessing=True, shuffle=True)
  File "C:\Users\Me\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
TypeError: fit_generator() got multiple values for argument 'steps_per_epoch'

Please advise where I should pass steps_per_epoch and data_y.

Tags: python, machine-learning, keras

Solution


The second parameter of fit_generator is steps_per_epoch. See the documentation here: https://www.tensorflow.org/api_docs/python/tf/keras/Model

fit_generator(
    generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None,
    validation_data=None, validation_steps=None, validation_freq=1,
    class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
    shuffle=True, initial_epoch=0
)

So when you call it like this:

hist = model.fit_generator(
    [data_x, data_micro[0], data_micro[1], data_micro[2]],
    data_y,
    steps_per_epoch=2,
    ...
)

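data_y is passed positionally, so it binds to steps_per_epoch, and steps_per_epoch=2 then supplies the same parameter again as a keyword argument. That is why Python reports multiple values for steps_per_epoch.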
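
The same binding rule can be reproduced without Keras. The sketch below uses a hypothetical stand-in function that copies only the first three parameters of fit_generator; it is not the Keras implementation, and the string values are placeholders for the real arrays.

```python
def fit_generator(generator, steps_per_epoch=None, epochs=1):
    """Hypothetical stand-in: mirrors only the leading parameters of Keras' fit_generator."""
    return generator, steps_per_epoch, epochs


def call_like_the_question():
    inputs = ["data_x", "micro_0", "micro_1", "micro_2"]  # placeholder values
    targets = "data_y"                                     # placeholder value
    # targets is the second positional argument, so it binds to steps_per_epoch;
    # the keyword steps_per_epoch=2 then tries to bind the same parameter again.
    return fit_generator(inputs, targets, steps_per_epoch=2, epochs=20)


try:
    call_like_the_question()
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

Because the first argument must be a generator that already yields the targets together with the inputs, there is no separate positional slot for data_y.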

Note that model.fit now supports generators as well, so fit_generator is deprecated.
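
Either method expects the generator itself to yield (inputs, targets) pairs. A minimal sketch of such a generator follows, assuming the data fits in memory. batch_generator and the toy lists are hypothetical names introduced for illustration, and whether multi-input batches should be a list or a tuple may depend on the Keras version, so check that against your installation.

```python
from itertools import islice


def batch_generator(inputs, targets, batch_size):
    """Yield (list_of_input_batches, target_batch) forever, cycling over the samples.

    Hypothetical helper: `inputs` is a list of equally long sequences
    (one per model input) and `targets` has the same length.
    """
    n_samples = len(targets)
    while True:  # Keras-style generators are expected to loop indefinitely
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            yield [x[start:stop] for x in inputs], targets[start:stop]


# Toy stand-ins for the model inputs and data_y (plain lists here;
# the same slicing works on NumPy arrays).
toy_inputs = [list(range(4)), list("abcd")]
toy_targets = [0, 1, 0, 1]

first_two = list(islice(batch_generator(toy_inputs, toy_targets, batch_size=2), 2))
print(first_two)

# With a real model the call would then look roughly like this (untested here):
# model.fit(batch_generator(inputs, data_y, batch_size), steps_per_epoch=2, epochs=20)
```

If the arrays are already in memory, as in the question, passing them directly as model.fit(x, y, ...) without a generator is usually the simpler route.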

