首页 > 解决方案 > keras.models.load_model 因 'tags' = train 失败

问题描述

我正在探索 tensorflow 2.0 的 c API。

问题:将模型加载到 python 中时,权重没有恢复,因此模型似乎未经训练。

工作流程:我正在使用 TF 2.0 C api 来处理我的模型的训练。我遵循的一般设置是:

1.使用TF keras api在python中定义模型。

import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([keras.layers.Dense(128,
                                             input_shape=(784,),
                                             activation='relu'),
                          keras.layers.Dense(10, activation='softmax')
                        ])
model.compile(optimizer='adam',
              loss="categorical_crossentropy",
              metrics=['accuracy'])

keras.experimental.export_saved_model(model,"keras_model")

我正在使用 keras.experimental.export_saved_model(),因为我需要使用 keras.Model.save() 时未保存的“signature_def['train']”。

2. 使用 TF 2.0 C api 在 C 中训练模型 保存的模型然后通过以下方式加载到我的 C 程序中:

TF_LoadSessionFromSavedModel()

随后对其进行训练并保存检查点:

TF_SessionRun()

保存模型会在存储模型的“变量”文件夹中创建新的检查点文件(“checkpoint.index”和“checkpoint.data-00000-of-00001”)。

3.问题在python中重新加载模型 训练后我在python中重新加载我的模型。这就是我发现加载的模型具有与未经训练的模型相对应的重量的地方。我知道这一点是因为当我用 C 语言训练的模型准确预测时,预测是胡言乱语。我通过以下方式加载我的模型:

import tensorflow as tf
from tensorflow import keras

model = keras.experimental.load_from_saved_model("keras_model")

同样,我正在使用 keras.experimental.load_from_saved_model() 因为当我使用 keras.models.load_model() 时,我得到以下 ValueError:

ValueError: Importing a SavedModel with tf.saved_model.load requires a 'tags=' argument if there is more than one MetaGraph. Got 'tags=None', but there are 3 MetaGraphs in the SavedModel with tag sets [['train'], ['eval'], ['serve']]. Pass a 'tags=' argument to load this SavedModel.

如果我将“tags=serve”参数传递给 keras.model.load_model(),我会得到以下 TypeError:

TypeError: load_model() got an unexpected keyword argument 'tags'

由于文件格式,尝试通过 keras.Model.load_weights() 将保存的检查点文件加载到我的模型中会导致 OSErro:

h5py/_objects.pyx in h5py._objects.with_phil.wrapper()
h5py/_objects.pyx in h5py._objects.with_phil.wrapper()
h5py/h5f.pyx in h5py.h5f.open()
OSError: Unable to open file (file signature not found)

问题 通过检查点文件加载训练模型的正确方法是什么?

如何保存使用 keras.experimental.export_saved_model 的模型并且仍然能够访问 signature_def['train']?

标签: pythonctensorflowkeras

解决方案


推荐阅读