首页 > 解决方案 > 将 Keras 分类器中的权重导入 TF 对象检测 API

问题描述

我有一个使用 keras 训练的分类器,效果非常好。它使用keras.applications.MobileNetV2.

这个分类器在大约 200 个类别上训练有素,并且具有很高的准确性。

但是,我想将此分类器中的特征提取层用作对象检测模型的一部分。

我一直在使用 Tensorflow 对象检测 API,并研究了 SSDLite+MobileNetV2 模型。我可以开始进行训练,但是训练非常缓慢,并且大部分损失来自分类阶段。

我想做的是将我的 keras.h5模型中的权重分配给 Tensorflow 中 MobileNetV2 的特征提取层,但我不确定最好的方法。

我可以h5轻松加载文件,并获取图层名称列表:

import keras

keras_model = keras.models.load_model("my_classifier.h5")

keras_names = [l.name for l in keras_model.layers]

print(keras_names)

我还可以从对象检测 API 恢复 tensorflow 检查点并导出具有权重的层:

tf.reset_default_graph()

with tf.Session() as sess:

    new_saver = tf.train.import_meta_graph('models/model.ckpt.meta')

    what = new_saver.restore(sess, 'models/model.ckpt')


    tf_names = []
    for op in sess.graph.get_operations():
        if "MobilenetV2" in op.name and "Assign" in op.name:
            tf_names.append(op.name)

    print(tf_names)

我似乎无法在 keras 和 tensorflow 的层名称之间找到很好的匹配。即使我可以,我也不确定接下来的步骤。

如果有人能给我一些关于解决这个问题的最佳方法的建议,我将不胜感激。

更新:

我遵循了 Sharky 的建议,稍作修改:

new_saver = tf.train.import_meta_graph(os.path.join(keras_checkpoint_dir, 'keras_model.ckpt.meta'))

new_saver.restore(sess, os.path.join(keras_checkpoint_dir, tf.train.latest_checkpoint(keras_checkpoint_dir)))

但是不幸的是,我现在收到此错误:

NotFoundError(参见上面的回溯):从检查点恢复失败。这很可能是由于检查点中缺少变量名称或其他图形键。请确保您没有根据检查点更改预期的图表。原始错误:

关键 FeatureExtractor/MobilenetV2/expanded_conv_6/project/BatchNorm/gamma not found in checkpoint [[node save/RestoreV2_295 (defined at :7) = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task :0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_295/tensor_names, save/RestoreV2_295/shape_and_slices)]] [[{{节点保存/RestoreV2_196/_393}} = _Recvclient_terminated=false, recv_device="/ job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_789_save /RestoreV2_196", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]]

关于如何摆脱这个错误的任何想法?

标签: tensorflowkeras

解决方案


您可以使用tf.keras.estimator.model_to_estimator

estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=path)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, os.path.join(path/keras, tf.train.latest_checkpoint(path/keras)))
    print(tf.global_variables())

这应该可以完成这项工作。请注意,它将在最初指定的路径内创建一个子目录。


推荐阅读