python - TypeError:fit_generator() 为参数“steps_per_epoch”获得了多个值
问题描述
我正在尝试训练 keras 模型。下面是火车模型的功能。
def train_model(input_videos, video_label, mapping, micro_expressions, val_x=None, val_micro=None, val_y=None, return_best=False):
print("train_model")
input_videos = np.asarray(input_videos)
data_x, data_y, data_mapping, data_micro = H1_preprocessing(input_videos, video_label, mapping, micro_expressions)
model, callbacks_lst = build_model()
print("data_x: ", data_x.shape)
print("data_y: ", data_y.shape)
if val_x is not None and val_y is not None and val_micro is not None:
hist = model.fit_generator([data_x, data_micro[0], data_micro[1], data_micro[2]], data_y, steps_per_epoch = 2, epochs = 20, verbose = 2, callbacks = callbacks_lst, validation_data=([val_x, val_micro[0], val_micro[1], val_micro[2]], val_y), use_multiprocessing=True, shuffle=True)
print(hist)
else:
hist = model.fit_generator([data_x, data_micro[0], data_micro[1], data_micro[2]], data_y, steps_per_epoch = 2, epochs = 20, verbose = 2, callbacks = callbacks_lst, use_multiprocessing=True, shuffle=True)
print(hist)
if return_best:
print("Applying weights")
model.load_weights("weights.best.hdf5")
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
fit_generator() 函数会产生错误。我不知道在正确的位置提供正确的参数。它显示以下错误。
Traceback (most recent call last):
File "concatenated_classifier.py", line 697, in <module>
trained_model = train_model(list_gray_train_videos, list_train_label, clips_mapping, micro_expressions, None, None, None, True)
File "concatenated_classifier.py", line 613, in train_model
hist = model.fit_generator([data_x, data_micro[0], data_micro[1], data_micro[2]], data_y, steps_per_epoch = 2, epochs = 20, verbose = 2, callbacks = callbacks_lst, use_multiprocessing=True, shuffle=True)
File "C:\Users\Me\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
TypeError: fit_generator() got multiple values for argument 'steps_per_epoch'
建议我在哪里提供 steps_per_epoch 和 data_y。
解决方案
的第二个参数fit_generator
是steps_per_epoch
。请参阅此处的文档:https ://www.tensorflow.org/api_docs/python/tf/keras/Model
fit_generator(
generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None,
validation_data=None, validation_steps=None, validation_freq=1,
class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
shuffle=True, initial_epoch=0
)
因此,当你这样称呼它时:
hist = model.fit_generator(
[data_x, data_micro[0], data_micro[1], data_micro[2]],
data_y,
steps_per_epoch=2,
...
)
您data_y
为此参数提供位置参数,并将其作为关键字参数提供。
请注意,model.fit
现在也支持生成器,因此不推荐使用此方法。
推荐阅读
- airflow - 启动和停止气流?
- delphi - 使用运行时库的 Delphi 64 位调试有错误的堆栈帧处于活动状态
- tensorflow - 密集函数输入形状与 keras 中的一个热编码训练数据不匹配
- javascript - 在 AnyChart 极坐标图中为每个象限设置背景
- angular - Angular 4 -> 使用自签名证书的 REST API 身份验证
- c++ - 名称查找和运算符重载如何工作?
- python - 输入5行数据,分类(预测)Keras LSTM中的第6行
- mongodb - 使用 'arrayfilters' 使用 Jongo 更新子数组
- string - 如何将字符串大写?
- android - error.org.jsonException.“opening_hours”没有值