首页 > 解决方案 > 无法将 keras 模型保存在数据块中

问题描述

我正在保存 keras 模型

model.save('model.h5')

在数据块中,但模型不保存,

我也试过保存在这里/tmp/model.h5提到

但模型不保存。

保存单元格执行但是当我加载模型时它显示没有model.h5文件可用。

当我这样做时dbfs_model_path = 'dbfs:/FileStore/models/model.h5' dbutils.fs.cp('file:/tmp/model.h5', dbfs_model_path)

或尝试加载模型

tf.keras.models.load_model("file:/tmp/model.h5")

我收到错误消息java.io.FileNotFoundException: File file:/tmp/model.h5 does not exist

标签: kerasdatabricksazure-databricks

解决方案


问题是 Keras 被设计为仅处理本地文件,因此它不理解 URI,例如dbfs:/、 或file:/. 因此,您需要使用本地路径来保存和加载操作,然后将文件复制到 DBFS 或从 DBFS 复制文件(不幸/dbfs的是,由于它的工作方式,它不能很好地与 Keras 配合使用)。

下面的代码工作得很好。请注意,dbfs:/orfile:/仅在对dbutils.fs命令的调用中使用 - Keras 的东西使用本地文件的名称。

  • 创建模型并在本地保存为/tmp/model-full.h5
from tensorflow.keras.applications import InceptionV3
model = InceptionV3(weights="imagenet")
model.save('/tmp/model-full.h5')
  • 将数据复制到 DBFSdbfs:/tmp/model-full.h5并检查它:
dbutils.fs.cp("file:/tmp/model-full.h5", "dbfs:/tmp/model-full.h5")
display(dbutils.fs.ls("/tmp/model-full.h5"))
  • 从 DBFS 复制文件/tmp/model-full2.h5并加载它:
dbutils.fs.cp("dbfs:/tmp/model-full.h5", "file:/tmp/model-full2.h5")
from tensorflow import keras
model2 = keras.models.load_model("/tmp/model-full2.h5")

推荐阅读