python - keras.models.load_model 因 'tags' = train 失败
问题描述
我正在探索 tensorflow 2.0 的 c API。
问题:将模型加载到 python 中时,权重没有恢复,因此模型似乎未经训练。
工作流程:我正在使用 TF 2.0 C api 来处理我的模型的训练。我遵循的一般设置是:
1.使用TF keras api在python中定义模型。
import tensorflow as tf
from tensorflow import keras
model = keras.Sequential([keras.layers.Dense(128,
input_shape=(784,),
activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss="categorical_crossentropy",
metrics=['accuracy'])
keras.experimental.export_saved_model(model,"keras_model")
我正在使用 keras.experimental.export_saved_model(),因为我需要使用 keras.Model.save() 时未保存的“signature_def['train']”。
2. 使用 TF 2.0 C api 在 C 中训练模型 保存的模型然后通过以下方式加载到我的 C 程序中:
TF_LoadSessionFromSavedModel()
随后对其进行训练并保存检查点:
TF_SessionRun()
保存模型会在存储模型的“变量”文件夹中创建新的检查点文件(“checkpoint.index”和“checkpoint.data-00000-of-00001”)。
3.问题在python中重新加载模型 训练后我在python中重新加载我的模型。这就是我发现加载的模型具有与未经训练的模型相对应的重量的地方。我知道这一点是因为当我用 C 语言训练的模型准确预测时,预测是胡言乱语。我通过以下方式加载我的模型:
import tensorflow as tf
from tensorflow import keras
model = keras.experimental.load_from_saved_model("keras_model")
同样,我正在使用 keras.experimental.load_from_saved_model() 因为当我使用 keras.models.load_model() 时,我得到以下 ValueError:
ValueError: Importing a SavedModel with tf.saved_model.load requires a 'tags=' argument if there is more than one MetaGraph. Got 'tags=None', but there are 3 MetaGraphs in the SavedModel with tag sets [['train'], ['eval'], ['serve']]. Pass a 'tags=' argument to load this SavedModel.
如果我将“tags=serve”参数传递给 keras.model.load_model(),我会得到以下 TypeError:
TypeError: load_model() got an unexpected keyword argument 'tags'
由于文件格式,尝试通过 keras.Model.load_weights() 将保存的检查点文件加载到我的模型中会导致 OSErro:
h5py/_objects.pyx in h5py._objects.with_phil.wrapper()
h5py/_objects.pyx in h5py._objects.with_phil.wrapper()
h5py/h5f.pyx in h5py.h5f.open()
OSError: Unable to open file (file signature not found)
问题 通过检查点文件加载训练模型的正确方法是什么?
如何保存不使用 keras.experimental.export_saved_model 的模型并且仍然能够访问 signature_def['train']?
解决方案
推荐阅读
- notepad++ - 如何在 Notepad++ 中隐藏轮廓?
- html - 使用 Font-Family 自定义字体
- sql - 将 Listview 项目添加到表中
- jenkins - 如何有条件地在环境指令中传递凭据绑定
- python - Tensorflow:如何查找 tf.data.Dataset API 对象的大小
- java - ArrayList 运行时打印不正确
- python-3.x - 即使配置文件的配置不这样做,Firefox 也会通过 python + selenium 继续在预览中打开 pdf
- angular - webpack 将 html 内容转换为 url 并尝试加载它
- ruby-on-rails - 将 ActionMailer RoR 模板与 Sendgrid API v3 一起使用
- javascript - 使用服务器上的模态文件更新图像,数据库中的名称