首页 > 解决方案 > 如何加载带有自定义损失的keras“历史”对象?

问题描述

所以我定义了我的 keras 模型并使用了 custom_loss 函数来训练模型:

model.compile(optimizer='adam', loss=custom_loss, metrics=[custom_loss])

然后我正在训练模型:

history = model.fit(X_train, y_train, batch_size=1024, epochs=125, validation_split=0.2, shuffle=True)

然后我使用以下代码保存这个历史对象:

with open('history.pkl', 'wb') as file:  
   pickle.dump(history, file)

现在,当我尝试按如下方式读取历史对象时:

with open('history.pkl', 'rb') as file:
    history = pickle.load(file)

我收到以下错误:

ValueError:未知损失函数:custom_loss

如何读取历史对象?当我不使用 custom_loss 函数时,我没有收到此错误。我正在使用 keras 2.2.4 和 tensorflow 1.15.5

编辑:按要求完成错误回溯: 在此处输入图像描述

标签: pythontensorflowkerasneural-networkpickle

解决方案


对于大多数用例,您不想序列化历史对象。您通常感兴趣的是 history.history,它是日志/指标/损失等的字典。

试试看:

pickle.dump(history.history, file)

更完整的答案是返回的历史对象是一个 tf.keras.callbacks.History,它是 tf.keras.callbacks.Callback 的子类。回调本身有一个对模型的引用,然后它对各种东西都有引用,包括自定义对象,比如你的自定义损失。Keras 自定义对象的序列化是另一个大话题...... tldr 序列化 Keras 模型的推荐方法是不使用 pickle。


推荐阅读