python - Loading weights from a checkpoint does not work in a Keras model
Problem description
I'm going crazy.
I defined a Sequential model with TensorFlow Keras:
import tensorflow as tf
from tensorflow import keras

model = tf.keras.Sequential([tf.keras.layers.Dense(128, input_shape=(784,), activation="relu"),
                             tf.keras.layers.Dense(10, activation="softmax")])
model.compile(optimizer="adam", loss="mse")
keras.experimental.export_saved_model(model, "keras_model")
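I then train this model in a C program using the TensorFlow C API (c_api.h).
The C program saves the weights to a checkpoint file.
When I try to restore the weights in Python from that checkpoint file: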
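Because the Python side has to find the weights the C program wrote, one thing worth checking first is whether the checkpoint keys match the Keras variable names. A name-based tf.train.Saver checkpoint (the format the deprecation warning in the log refers to) normally stores each variable under its name without the `:0` output suffix, so `a/kernel:0` is looked up as `a/kernel`. The helper below is a minimal sketch under that assumption; `checkpoint_key` and `diff_names` are hypothetical names, not TensorFlow APIs, and the list of checkpoint keys is assumed to come from something like `tf.train.list_variables`.

```python
from typing import Dict, Iterable, List


def checkpoint_key(var_name: str) -> str:
    """Turn a Keras variable name such as 'a/kernel:0' into the key a
    name-based checkpoint would normally store it under ('a/kernel')."""
    # Drop the ':<output index>' suffix, if present.
    return var_name.rsplit(":", 1)[0]


def diff_names(model_var_names: Iterable[str],
               checkpoint_names: Iterable[str]) -> Dict[str, List[str]]:
    """Compare the keys the model expects with the keys the checkpoint holds."""
    expected = {checkpoint_key(n) for n in model_var_names}
    available = set(checkpoint_names)
    return {
        # Keras variables with no tensor of the same name in the checkpoint.
        "missing_from_checkpoint": sorted(expected - available),
        # Checkpoint entries the model does not ask for; optimizer slots or
        # counters written by the C training loop may also show up here.
        "unused_in_checkpoint": sorted(available - expected),
    }


# Example use with TensorFlow (not executed here):
#   import tensorflow as tf
#   ckpt = [name for name, _ in tf.train.list_variables("keras_model/variables/variables")]
#   print(diff_names([w.name for w in model.weights], ckpt))
```

A non-empty `missing_from_checkpoint` list would point at a naming mismatch rather than at the restore call itself.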
# Option 1: reload the whole exported SavedModel
keras.experimental.load_from_saved_model("keras_model/")
# Option 2: rebuild the same architecture and load only the weights
model = tf.keras.Sequential([tf.keras.layers.Dense(128, input_shape=(784,), activation="relu"),
                             tf.keras.layers.Dense(10, activation="softmax")])
model.load_weights("keras_model/variables/variables")
# Option 3: object-based restore through tf.train.Checkpoint
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore("keras_model/variables/variables")
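I end up with an error, and the weights are not restored.
I am, however, able to restore the weights and continue training in my C program.
Full output of the first call: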
keras.experimental.load_from_saved_model("keras_model/")
WARNING: Logging before flag parsing goes to stderr.
W0918 15:18:04.350199 140418474760000 deprecation.py:323] From <ipython-input-2-06ea110fdc8e>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been deprecated. Please switch to `tf.keras.models.load_model`.
2019-09-18 15:18:04.390271: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1696040000 Hz
2019-09-18 15:18:04.390913: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x4bf4790 executing computations on platform Host. Devices:
2019-09-18 15:18:04.390961: I tensorflow/compiler/xla/service/service.cc:175] StreamExecutor device (0): Host, Default Version
W0918 15:18:04.436281 140418474760000 deprecation.py:323] From /home/jregalado/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py:1249: NameBasedSaverStatus.__init__ (from tensorflow.python.training.tracking.util) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-2-06ea110fdc8e> in <module>
----> 1 keras.experimental.load_from_saved_model("keras_model/")
~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py in new_func(*args, **kwargs)
322 'in a future version' if date is None else ('after %s' % date),
323 instructions)
--> 324 return func(*args, **kwargs)
325 return tf_decorator.make_decorator(
326 func, new_func, 'deprecated',
~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saved_model_experimental.py in load_from_saved_model(saved_model_path, custom_objects)
425 compat.as_text(constants.VARIABLES_DIRECTORY),
426 compat.as_text(constants.VARIABLES_FILENAME))
--> 427 model.load_weights(checkpoint_prefix)
428 return model
~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in load_weights(self, filepath, by_name)
179 raise ValueError('Load weights is not yet supported with TPUStrategy '
180 'with steps_per_run greater than 1.')
--> 181 return super(Model, self).load_weights(filepath, by_name)
182
183 @trackable.no_automatic_dependency_tracking
~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in load_weights(self, filepath, by_name)
1372 # streaming restore for any variables created in the future.
1373 trackable_utils.streaming_restore(status=status, session=session)
-> 1374 status.assert_nontrivial_match()
1375 return status
1376 if h5py is None:
~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in assert_nontrivial_match(self)
964 # assert_nontrivial_match and assert_consumed (and both are less
965 # useful since we don't touch Python objects or Python state).
--> 966 return self.assert_consumed()
967
968 def _gather_saveable_objects(self):
~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in assert_consumed(self)
941 raise AssertionError(
942 "Some objects had attributes which were not restored:{}".format(
--> 943 "".join(unused_attribute_strings)))
944 for trackable in self._graph_view.list_objects():
945 # pylint: disable=protected-access
AssertionError: Some objects had attributes which were not restored:
<tf.Variable 'a/kernel:0' shape=(784, 128) dtype=float32, numpy=
array([[-0.03716458, -0.04911711, -0.01023878, ..., 0.0636776 ,
0.02892563, -0.05542086],
[-0.02324755, -0.07362694, -0.0399951 , ..., 0.0680329 ,
0.05201877, -0.05149256],
[ 0.00954343, 0.05673491, 0.05108347, ..., 0.01994208,
-0.01107961, 0.06192174],
...,
[ 0.07091486, -0.07734856, -0.04417738, ..., 0.01921409,
-0.01908814, -0.05070668],
[ 0.01353646, -0.05189713, -0.01391671, ..., -0.05795977,
0.04801518, 0.00801209],
[-0.05304915, 0.01870193, 0.05657425, ..., -0.06819408,
-0.00760372, -0.0106293 ]], dtype=float32)>: ['a/kernel']
<tf.Variable 'a/bias:0' shape=(128,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>: ['a/bias']
<tf.Variable 'b/kernel:0' shape=(128, 10) dtype=float32, numpy=
array([[-0.1759212 , -0.09282549, -0.11045764, ..., -0.13727605,
-0.02849793, 0.14510198],
[ 0.06857841, -0.01459177, 0.08369003, ..., 0.05089156,
-0.05319159, -0.08594933],
[-0.180914 , -0.18932283, 0.20551099, ..., -0.17210156,
-0.10069884, 0.06433241],
...,
[ 0.09097584, -0.03930017, -0.15125516, ..., 0.02359283,
-0.16158347, -0.13176063],
[-0.04145582, -0.03205152, 0.20097663, ..., -0.15124482,
0.16874255, -0.15434337],
[-0.13188484, 0.04145408, 0.05036192, ..., -0.10489662,
0.12316228, 0.08794598]], dtype=float32)>: ['b/kernel']
<tf.Variable 'b/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>: ['b/bias']
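The assertion appears to list each unrestored Keras variable next to the name it was looked up under (for example `a/kernel`). If the C program wrote those tensors under different names, or if you want to bypass the Keras matching check altogether, one possible workaround, untested against this checkpoint, is to read the tensors yourself and pass them to `model.set_weights`. `collect_weights` and its `rename` argument are hypothetical; the only TensorFlow pieces assumed are `tf.train.load_checkpoint`, whose reader exposes `get_variable_to_shape_map()` and `get_tensor()`, and `Model.set_weights`.

```python
from typing import Dict, List, Optional, Sequence


def collect_weights(reader, var_names: Sequence[str],
                    rename: Optional[Dict[str, str]] = None) -> List:
    """Return one checkpoint tensor per Keras variable, in var_names order.

    reader: any object with get_variable_to_shape_map() and get_tensor(key),
        e.g. the reader returned by tf.train.load_checkpoint(prefix).
    var_names: Keras variable names such as 'a/kernel:0', in model.weights order.
    rename: hypothetical mapping from the key Keras expects to the key the
        C program actually wrote, for checkpoints whose names differ.
    """
    rename = rename or {}
    available = set(reader.get_variable_to_shape_map())
    keys = []
    for name in var_names:
        key = name.rsplit(":", 1)[0]  # 'a/kernel:0' -> 'a/kernel'
        keys.append(rename.get(key, key))
    missing = [k for k in keys if k not in available]
    if missing:
        raise KeyError("not found in checkpoint: " + ", ".join(missing))
    return [reader.get_tensor(k) for k in keys]


# Example use with TensorFlow (not executed here):
#   import tensorflow as tf
#   reader = tf.train.load_checkpoint("keras_model/variables/variables")
#   model.set_weights(collect_weights(reader, [w.name for w in model.weights]))
```

`set_weights` expects arrays in `model.weights` order, which is why the helper preserves the order of `var_names`; any extra entries in the checkpoint, such as optimizer state, are simply not read.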