tensorflow - Keras ValueError 尝试加载模型
问题描述
我正在使用 Anaconda Navigator,确切地说是 Jupyter。
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
>>> 1.14.0
这是我的模型
def create_model():
model = tf.keras.Sequential([
keras.layers.Dense(86, activation='relu', kernel_regularizer=keras.regularizers.l2(0.0001),input_shape=(129,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(142, activation='relu', kernel_regularizer=keras.regularizers.l2(0.0001)),
keras.layers.Dropout(0.2),
keras.layers.Dense(4, activation='softmax')
])
return model
model = create_model()
# Display the model's architecture
model.summary()
在训练、预测和评估我的模型之后,我决定使用
model.save('/Users/Jennifer/myproject/my_model.h5')
我用 h5py 文件检查了目录和文件夹。我决定使用
new_model1 = tf.keras.models.load_model('/Users/Jennifer/myproject/my_model.h5')
我有一个错误
ValueError: Unknown entries in loss dictionary: ['class_name', 'config']. Only expected following keys: ['dense_17']
请帮我。我应该怎么办?我几乎花了一整天的时间来解决这个问题。谢谢
解决方案
这是一个只加载权重的解决方法:
#!/usr/bin/env python3
from tensorflow import keras
import os
def create_model():
model = keras.Sequential([
keras.layers.Dense(86, activation='relu', kernel_regularizer=keras.regularizers.l2(0.0001),input_shape=(129,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(142, activation='relu', kernel_regularizer=keras.regularizers.l2(0.0001)),
keras.layers.Dropout(0.2),
keras.layers.Dense(4, activation='softmax')
])
return model
if os.path.exists("junk.h5"):
model = create_model()
model.load_weights("junk.h5")
else:
model = create_model()
model.compile(optimizer=keras.optimizers.Adam(0.0001), loss=keras.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
model.save("junk.h5")
另一种解决方法是在没有优化器的情况下保存模型
model.save("junk.h5", include_optimizer=False)
看起来您正在使用的损失函数创建了一个包含无效键的字典。这听起来像是 keras/tensorflow 中的错误。这就是 colab 可能工作的原因,因为它使用的是更新版本。
推荐阅读
- c++ - 为什么 const char* const & = "hello" 可以编译?
- c# - 在c#中返回元组字典作为函数的返回类型
- angular - 如何正确实施“photo.service”?
- javascript - 锚标签-在 iphone7/6s plus 中没有在 mousedown 上获取日志
- cordova - 如何在ionic 2中使用couchbase-lite cordova插件实现同步功能?
- python - 防止 requests-mock 降低我的帖子数据的大小写
- javascript - 使用 websocket 时消息被延迟
- node.js - 如何使用正文有效负载发出 HTTP/2 请求?
- javascript - 从Javascript中的关联数组中删除项目
- kubernetes - 使用 TcpDiscoveryKubernetesIpFinder 在 Kubernetes 集群中无法发现 Ignite