python - 如何加载带有自定义损失的keras“历史”对象?
问题描述
所以我定义了我的 keras 模型并使用了 custom_loss 函数来训练模型:
model.compile(optimizer='adam', loss=custom_loss, metrics=[custom_loss])
然后我正在训练模型:
history = model.fit(X_train, y_train, batch_size=1024, epochs=125, validation_split=0.2, shuffle=True)
然后我使用以下代码保存这个历史对象:
with open('history.pkl', 'wb') as file:
pickle.dump(history, file)
现在,当我尝试按如下方式读取历史对象时:
with open('history.pkl', 'rb') as file:
history = pickle.load(file)
我收到以下错误:
ValueError:未知损失函数:custom_loss
如何读取历史对象?当我不使用 custom_loss 函数时,我没有收到此错误。我正在使用 keras 2.2.4 和 tensorflow 1.15.5
解决方案
对于大多数用例,您不想序列化历史对象。您通常感兴趣的是 history.history,它是日志/指标/损失等的字典。
试试看:
pickle.dump(history.history, file)
更完整的答案是返回的历史对象是一个 tf.keras.callbacks.History,它是 tf.keras.callbacks.Callback 的子类。回调本身有一个对模型的引用,然后它对各种东西都有引用,包括自定义对象,比如你的自定义损失。Keras 自定义对象的序列化是另一个大话题...... tldr 序列化 Keras 模型的推荐方法是不使用 pickle。
推荐阅读
- node.js - 为什么 expressjs POST 请求返回 [object Object]?
- arrays - 来自 webhook 的无效响应:无法将 JSON 转换为 ExecuteHttpResponse
- aws-lambda - 用于突变的 AWS Appsync Lambda 自定义解析器
- flutter - Flutter 转换为没有“[]”的字符串
- django - 表格不显示模型中的数据
- python - 将字典列表中的值从字符串转换为整数
- sql - 插入偶尔丢失的行
- qgis - Qgis:为报告表计算不同类型线的距离
- dynamic - 方案 - 动态范围 - 为什么这是返回值?
- python - 在 RBM 损失训练中损失没有减少