首页 > 解决方案 > 由于 eval() 函数,Tensorflow 推理变得越来越慢

问题描述

所以我有一个冻结的张量流模型,可以用来对图像进行分类。当我尝试使用此模型逐个推断图像时,模型运行速度越来越慢。我搜索并发现问题可能是由 eval() 函数引起的,它会不断向图中添加新节点,从而减慢过程。

以下是我的代码的关键部分:

with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    sess1 = tf.Session()
    sess = tf.Session()

    for root, dirs, files in os.walk(file_path):
        for f in files:
            # Read image one by one and preprocess
            img = cv2.imread(os.path.join(root, f))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR 2 RGB

            img = image_preprocessing_fn(img, _IMAGE_HEIGHT, _IMAGE_WIDTH)  # This function contains tf functions
            img = img.eval(session=sess1)

            img = np.reshape(img, [-1, _IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNEL])    # the input shape is 4 dimension

            # Feed image to model
            data = sess.graph.get_tensor_by_name('input:0')
            predict = sess.graph.get_tensor_by_name('resnet_v1_50/predictions/Softmax:0')

            out = sess.run(predict, feed_dict={data: img})
            indices = np.argmax(out, 1)

            print('Current image name: %s, predict result: %s' % (f, indices))

    sess1.close()
    sess.close()

PS:我用“sess1”做预处理,我觉得可能不合适。希望有人能告诉我正确的方法,在此先感谢。

标签: tensorflowinference

解决方案


没有人回答......这是我的解决方案,它有效!

with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    x = tf.placeholder(tf.uint8, shape=[None, None, 3])
    y = image_preprocessing_fn(x, _IMAGE_HEIGHT, _IMAGE_WIDTH)

    sess = tf.Session()
    data = sess.graph.get_tensor_by_name('input:0')
    predict = sess.graph.get_tensor_by_name('resnet_v1_50/predictions/Softmax:0')

    for root, dirs, files in os.walk(file_path):
        for f in files:
            img = cv2.imread(os.path.join(root, f))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR 2 RGB
            img = sess.run(y, feed_dict={x: img})

            img = np.reshape(img, [-1, _IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNEL])

            out = sess.run(predict, feed_dict={data: img})
            indices = np.argmax(out, 1)

            print('Current image name: %s, predict result: %s' % (f, out))

    sess.close()

推荐阅读