首页 > 解决方案 > 预训练的 TensorFlow 模型无效参数错误

问题描述

我正在使用带有预训练 mobilenet_v2 模型的 tensorflow 来完成我的项目,该模型可以在https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md上找到

我想获得隐藏层的值,所以我实现了这个源代码,我得到了一个无效参数错误

if __name__ == '__main__':
    im = Image.open('./sample/maltiz.png')
    im3 = im.resize((300, 300))

    image = np.asarray(im)[:,:,:3]

    model_path = 'models/ssd_mobilenet_v2_coco_2018_03_29/'

    meta_path = os.path.join(model_path, 'model.ckpt.meta')
    model = tf.train.import_meta_graph(meta_path)

    sess = tf.Session()
    model.restore(sess, tf.train.latest_checkpoint(model_path))

    data = np.array([image])
    data = data.astype(np.uint8)

    X = tf.placeholder(tf.uint8, shape=[None, None, None, 3])

    graph = tf.get_default_graph()

    for i in graph.get_operations():
        if "Relu" in i.name:
            print(sess.run(i.values(), feed_dict = { X : data}))

我收到此错误消息

File "load_model.py", line 42, in <module>

    print(sess.run(i.values(), feed_dict = { X : data}))
InvalidArgumentError: You must feed a value for placeholder tensor 'image_tensor' with dtype uint8 and shape [?,?,?,3]

[[node image_tensor (defined at load_model.py:24) ]]

我打印了占位符和数据的形状。

占位符是 uint8 类型的 [?,?,?,3] 并且图像的形状为 [1,300,300,3] 我不知道有什么问题。

它看起来与错误消息上的类型完美匹配。

请让我知道有什么问题。

标签: pythontensorflowmachine-learninginferencepre-trained-model

解决方案


当您加载预定义图并将图恢复到最新的检查点时,图已定义。但是当你这样做时

X = tf.placeholder(tf.uint8, shape=[None, None, None, 3])

您正在图中创建一个额外的节点。并且该节点与您要评估的节点无关,来自的节点graph.get_operations() 不依赖于这个额外的节点,而是依赖于其他节点,并且由于该其他节点没有得到值,因此错误表示无效参数。

正确的方法是从预定义的图中获取要评估的节点所依赖的张量。

im = Image.open('./sample/maltiz.png')
im3 = im.resize((300, 300))

image = np.asarray(im)[:,:,:3]

model_path = 'models/ssd_mobilenet_v2_coco_2018_03_29/'

meta_path = os.path.join(model_path, 'model.ckpt.meta')
model = tf.train.import_meta_graph(meta_path)

sess = tf.Session()
model.restore(sess, tf.train.latest_checkpoint(model_path))

data = np.array([image])
data = data.astype(np.uint8)

graph = tf.get_default_graph()
X = graph.get_tensor_by_name('image_tensor:0')

for i in graph.get_operations():
    if "Relu" in i.name:
        print(sess.run(i.values(), feed_dict = { X : data}))

PS:我自己确实尝试过上述方法,但是有一些 tensorflow(版本 1.13.1)内部错误阻止我评估Relu其名称中包含的所有节点。但是仍然可以通过这种方式评估一些节点。


推荐阅读