tensorflow - 将 Keras 分类器中的权重导入 TF 对象检测 API
问题描述
我有一个使用 keras 训练的分类器,效果非常好。它使用keras.applications.MobileNetV2
.
这个分类器在大约 200 个类别上训练有素,并且具有很高的准确性。
但是,我想将此分类器中的特征提取层用作对象检测模型的一部分。
我一直在使用 Tensorflow 对象检测 API,并研究了 SSDLite+MobileNetV2 模型。我可以开始进行训练,但是训练非常缓慢,并且大部分损失来自分类阶段。
我想做的是将我的 keras.h5
模型中的权重分配给 Tensorflow 中 MobileNetV2 的特征提取层,但我不确定最好的方法。
我可以h5
轻松加载文件,并获取图层名称列表:
import keras
keras_model = keras.models.load_model("my_classifier.h5")
keras_names = [l.name for l in keras_model.layers]
print(keras_names)
我还可以从对象检测 API 恢复 tensorflow 检查点并导出具有权重的层:
tf.reset_default_graph()
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('models/model.ckpt.meta')
what = new_saver.restore(sess, 'models/model.ckpt')
tf_names = []
for op in sess.graph.get_operations():
if "MobilenetV2" in op.name and "Assign" in op.name:
tf_names.append(op.name)
print(tf_names)
我似乎无法在 keras 和 tensorflow 的层名称之间找到很好的匹配。即使我可以,我也不确定接下来的步骤。
如果有人能给我一些关于解决这个问题的最佳方法的建议,我将不胜感激。
更新:
我遵循了 Sharky 的建议,稍作修改:
new_saver = tf.train.import_meta_graph(os.path.join(keras_checkpoint_dir, 'keras_model.ckpt.meta'))
new_saver.restore(sess, os.path.join(keras_checkpoint_dir, tf.train.latest_checkpoint(keras_checkpoint_dir)))
但是不幸的是,我现在收到此错误:
NotFoundError(参见上面的回溯):从检查点恢复失败。这很可能是由于检查点中缺少变量名称或其他图形键。请确保您没有根据检查点更改预期的图表。原始错误:
关键 FeatureExtractor/MobilenetV2/expanded_conv_6/project/BatchNorm/gamma not found in checkpoint [[node save/RestoreV2_295 (defined at :7) = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task :0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_295/tensor_names, save/RestoreV2_295/shape_and_slices)]] [[{{节点保存/RestoreV2_196/_393}} = _Recvclient_terminated=false, recv_device="/ job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_789_save /RestoreV2_196", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]]
关于如何摆脱这个错误的任何想法?
解决方案
您可以使用tf.keras.estimator.model_to_estimator
estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=path)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, os.path.join(path/keras, tf.train.latest_checkpoint(path/keras)))
print(tf.global_variables())
这应该可以完成这项工作。请注意,它将在最初指定的路径内创建一个子目录。
推荐阅读
- javascript - 如何在不丢失样式的情况下打印vue组件的一部分
- angular - Angular rxjs Observable.timer 不是具有导入功能的函数
- javascript - Twitch API 和显示在线状态
- regex - java regex pattern.compile Vs 匹配器
- sql - 如何使用联接从两个表中获取不匹配的记录
- java - 超类中的私有变量如何在这里的子类中继承?
- asp.net-mvc - 整数值的数据注释验证
- angular - android 版本的 Ionic cordova 插件
- html - 如何将背景图像浮动到左侧?
- javascript - 如何在节点控制台中记录深度嵌套的对象