tensorflow - 由于 eval() 函数,Tensorflow 推理变得越来越慢
问题描述
所以我有一个冻结的张量流模型,可以用来对图像进行分类。当我尝试使用此模型逐个推断图像时,模型运行速度越来越慢。我搜索并发现问题可能是由 eval() 函数引起的,它会不断向图中添加新节点,从而减慢过程。
以下是我的代码的关键部分:
with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
sess1 = tf.Session()
sess = tf.Session()
for root, dirs, files in os.walk(file_path):
for f in files:
# Read image one by one and preprocess
img = cv2.imread(os.path.join(root, f))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR 2 RGB
img = image_preprocessing_fn(img, _IMAGE_HEIGHT, _IMAGE_WIDTH) # This function contains tf functions
img = img.eval(session=sess1)
img = np.reshape(img, [-1, _IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNEL]) # the input shape is 4 dimension
# Feed image to model
data = sess.graph.get_tensor_by_name('input:0')
predict = sess.graph.get_tensor_by_name('resnet_v1_50/predictions/Softmax:0')
out = sess.run(predict, feed_dict={data: img})
indices = np.argmax(out, 1)
print('Current image name: %s, predict result: %s' % (f, indices))
sess1.close()
sess.close()
PS:我用“sess1”做预处理,我觉得可能不合适。希望有人能告诉我正确的方法,在此先感谢。
解决方案
没有人回答......这是我的解决方案,它有效!
with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
x = tf.placeholder(tf.uint8, shape=[None, None, 3])
y = image_preprocessing_fn(x, _IMAGE_HEIGHT, _IMAGE_WIDTH)
sess = tf.Session()
data = sess.graph.get_tensor_by_name('input:0')
predict = sess.graph.get_tensor_by_name('resnet_v1_50/predictions/Softmax:0')
for root, dirs, files in os.walk(file_path):
for f in files:
img = cv2.imread(os.path.join(root, f))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR 2 RGB
img = sess.run(y, feed_dict={x: img})
img = np.reshape(img, [-1, _IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNEL])
out = sess.run(predict, feed_dict={data: img})
indices = np.argmax(out, 1)
print('Current image name: %s, predict result: %s' % (f, out))
sess.close()
推荐阅读
- angularjs - Angular 在更改状态后显示 2 个视图
- angular - Rx Observable 管道处理错误
- scala - 实施 Supervision.Resume 时的 Akka Stream 测试流程
- debian - 查找 debian jessie 的旧分支
- node.js - Socket.io 优化:socket.io 是否发送 JSON 字符串?我们如何优化它?
- html - 占宽度 100% 的复选框
- java - gcj 错误 - 找不到类 java.util.function.Predicate 的文件
- vue.js - 如果孩子在vue中有课程,如何设置父母的样式?
- vba - 使用标准复制和粘贴值 vba
- html - 变暗的图像问题css