python - 使用泡菜保存模型时出现“TypeError:无法泡菜'_thread.RLock'对象”
问题描述
我正在尝试将我的 keras 模型保存到 pickle 文件中,但出现此错误。有什么办法可以解决?或者保存和加载模型的更好方法是什么?我正在二进制预测 480x640 灰度图像。
遵循我的代码:
def trainModel(data):
batch_size = 3
img_height = 480
img_width = 640
trainDataset = tf.keras.preprocessing.image_dataset_from_directory(
data,
validation_split=0.25,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size,
#class_names={"nao_doentes", "doentes"}
)
valDataset = tf.keras.preprocessing.image_dataset_from_directory(
data,
validation_split=0.25,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size,
#class_names={"nao_doentes", "doentes"}
)
AUTOTUNE = tf.data.experimental.AUTOTUNE
trainDataset = trainDataset.cache().prefetch(buffer_size=AUTOTUNE)
valDataset = valDataset.cache().prefetch(buffer_size=AUTOTUNE)
num_classes = 2
model = tf.keras.Sequential([
layers.experimental.preprocessing.Rescaling(1./255),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
model.compile(
optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=['accuracy']
)
model.fit(
trainDataset,
validation_data=valDataset,
epochs=10
)
return model
model = trainModel(training_data)
with open('model.sav', 'wb') as f:
pickle.dump(model, f)
with open('model.sav', 'rb') as f:
model = pickle.load(f)
testing = np.ndarray(shape=(1, 1, 480, 640), dtype=np.float32)
image = load_img(os.path.join(test_data, "doentes/doente_6.jpg"), target_size=(480,640))
x = img_to_array(image)
x = np.expand_dims(x, axis=0)
testing = np.vstack([x])
print(model.predict(testing))
另外,当问题与图像分类情况有关时,您是否愿意提供一些良好实践和解释的良好来源提供建议?我是该地区的新手,因此在搜索和链接不同来源的信息时我有点挣扎。
解决方案
一般来说,pickle 在为 pytorch、tensorflow 和 keras 保存 ml 模型权重时存在问题。要保存您的 keras 模型,请查看他们的教程
具体来说,尝试在 keras 中使用函数 save 和 load_module:
model.save('path/to/location')
reconstructed_model = keras.models.load_model("path/to/location")
推荐阅读
- c# - C# 使用数组启用/禁用按钮,Win 表单/MVP
- css - 如果我为侧边栏设置位置:固定,图标就会消失
- sql - 用于合并重复项的存储过程 Firebird
- vhdl - 在 VHDL 中为多个 IP 使用相同的组件
- reactjs - 如何在打字稿界面中使用两种类型的事件
- python-3.x - 在 pyspark 中使用自定义顺序选择最大/最大值
- python - 使用 requirements.txt 安装时 pip 拒绝二进制文件
- javascript - 跨站点 postMessage 未被 addEventListener 拾取
- javascript - 模式打开时如何允许背景div可滚动
- python-3.x - 如何通过 Python 和 pytest 获取 dict “Current”