首页 > 解决方案 > 如何正确加载和运行 2 个模型?

问题描述

我正在尝试加载一个 keras 模型和 2 个 tf 模型并在一段代码中按顺序运行它们。但是它不起作用。我怎样才能正确使用它们?

环境:Python-3.5,Tensorflow-1.14.0

这是我的代码框架

load keras_model
load tf_model_1
load tf_model_2
predict using keras_model
predict using tf_model_1
predict using tf_model_2

关于加载和使用 tf_model 的详细信息

g1 = tf.Graph()
sess_1 = tf.Session(graph=g1, config=get_config())


with sess_1.as_default():
    with sess_1.graph.as_default():
        print(tf.get_default_graph())
        output = build_model()
        init_fn = slim.assign_from_checkpoint_fn(
            'model1.ckpt',
            slim.get_model_variables(),
            ignore_missing_vars=True)
        init_op = tf.global_variables_initializer()
        sess_1.run(init_op)
        init_fn(sess_1)
###########################
g2 = tf.Graph()
sess_2= tf.Session(graph=g2, config=get_config())


with sess_2.as_default():
    with sess_2.graph.as_default():
        print(tf.get_default_graph())
        output = build_model()
        init_fn = slim.assign_from_checkpoint_fn(
            'model2.ckpt',
            slim.get_model_variables(),
            ignore_missing_vars=True)
        init_op = tf.global_variables_initializer()
        sess_2.run(init_op)
        init_fn(sess_2)

使用 tf_model 时

with sess_1.as_default():
    with sess_1.graph.as_default():
        out = sess_1.run(
            [output], feed_dict={inputs: x})

3个模型都可以独立正常工作,但是当我尝试按我的顺序组合时,运行tf_model_1时报错:

TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(?, 224, 224, 3), dtype=float32) is not an element of this graph.

我已经尝试将它们与这个序列结合起来:

1.load keras_model
2.load tf_model_1
3.load tf_model_2
4.predict using keras_model
5.predict using tf_model_2
6.predict using tf_model_1

它可以运行到第6步并抛出错误。所以我怀疑这是2个不同的张量流图的问题。我在加载和运行tf_model时加入了这一行

print(tf.get_default_graph())

我得到了输出

loading tf_model_1:
<tensorflow.python.framework.ops.Graph object at 0x7fca53563550>
loading tf_model_2:
<tensorflow.python.framework.ops.Graph object at 0x7fca53559ba8>
runnng tf_model_1:
<tensorflow.python.framework.ops.Graph object at 0x7fca53563550>

运行相应模型时似乎设置了正确的tf图,这让我很困惑。同样没有第二个 tf_model,我连接了 keras 模型和一个 tf 模型。它可以工作。

1.load keras_model
2.load tf_model_1

3.predict using keras_model
4.predict using tf_model_1

我怎样才能正确使用它们?

标签: pythontensorflow

解决方案


推荐阅读