python - 将图像预测存储到 csv 的 Python for 循环启动很快但速度变慢
问题描述
我有一个包含超过 100,000 张图像的文件夹,我想使用 TensorFlow 对其进行分类。我编写了一个 for 循环,它遍历每个图像,返回一个置信度分数,并将预测结果存储到一个 csv 文件中。
问题是:脚本启动非常快(图像 1-1000 每秒大约 10 张图像),并且随着每次迭代逐渐减慢(图像> 1000 每秒只有大约 1 张图像)。
对于 Python 中 for 循环的类似减速问题,我读到预分配可能是一种解决方案。但是,我直接写入 csv 而不是列表,所以我不确定这应该如何帮助。
有什么办法可以确保整个循环过程中的速度一致?
提前感谢您的任何指点!
请在下面找到我的代码,该代码基于本教程(https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0):
filename = "predictions.csv"
f = open(filename, "w")
headers = "id;image_name;confidence\n"
f.write(headers)
start = 1
end = 20000
testdata = "C:/files/"
files = list(os.listdir(testdata))
for index in range(start, end+1):
filename = files[index]
if not filename.startswith('.'):
print(str(index) + " - " + str(filename))
image=testdata+filename
results = label_image(image, graph, session, input_height=299, input_width=299, input_layer="Mul")
f.write(str(index) + ";" + str(filename) + ";" + str(results[0]) + "\n")
print("\n")
f.close()
编辑:
在运行循环之前,我只加载了一次图表。
from scripts.label_image import load_graph, label_image, get_session
model_file = "retrained_graph.pb"
graph = load_graph(model_file)
session = get_session(graph)
编辑2:
这是 label_image 函数的代码。
def label_image(file_name, graph, session, label_file="retrained_labels.txt", input_height=224, input_width=224, input_mean=128, input_std=128, input_layer="input", output_layer="final_result"):
t = read_tensor_from_image_file(file_name,
input_height=input_height,
input_width=input_width,
input_mean=input_mean,
input_std=input_std)
input_name = "import/" + input_layer
output_name = "import/" + output_layer
input_operation = graph.get_operation_by_name(input_name);
output_operation = graph.get_operation_by_name(output_name);
start = time.time()
results = session.run(output_operation.outputs[0],
{input_operation.outputs[0]: t})
end=time.time()
results = np.squeeze(results)
top_k = results.argsort()[-5:][::-1]
labels = load_labels(label_file)
print('\nEvaluation time (1-image): {:.3f}s\n'.format(end-start))
template = "{} (score={:0.5f})"
for i in top_k:
print(template.format(labels[i], results[i]))
return results
编辑 3:
这是 read_tensor_from_image_file 函数的代码。
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
input_mean=0, input_std=255):
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
if file_name.endswith(".png"):
image_reader = tf.image.decode_png(file_reader, channels = 3,
name='png_reader')
elif file_name.endswith(".gif"):
image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
name='gif_reader'))
elif file_name.endswith(".bmp"):
image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
else:
image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0);
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
sess = tf.Session()
result = sess.run(normalized)
return result
编辑4:
这是我重构的代码,它向我抛出了错误: AttributeError: 'Tensor' object has no attribute 'endswith'
def process_image(file_name):
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
if file_name.endswith(".png"):
image_reader = tf.image.decode_png(file_reader, channels = 3,
name='png_reader')
elif file_name.endswith(".gif"):
image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
name='gif_reader'))
elif file_name.endswith(".bmp"):
image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
else:
image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0);
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
return normalized
filename_placeholder = tf.placeholder(tf.string)
processed = process_image(filename_placeholder)
def label_image(file_name, graph, session, label_file="tf_files/retrained_labels.txt", input_height=224, input_width=224, input_mean=128, input_std=128, input_layer="input", output_layer="final_result"):
result = sess.run(processed, feed_dict={filename_placeholder: file_name})
input_name = "import/" + input_layer
output_name = "import/" + output_layer
input_operation = graph.get_operation_by_name(input_name);
output_operation = graph.get_operation_by_name(output_name);
start = time.time()
results = session.run(output_operation.outputs[0],
{input_operation.outputs[0]: t})
end=time.time()
results = np.squeeze(results)
top_k = results.argsort()[-5:][::-1]
labels = load_labels(label_file)
print('\nEvaluation time (1-image): {:.3f}s\n'.format(end-start))
template = "{} (score={:0.5f})"
for i in top_k:
print(template.format(labels[i], results[i]))
return results
解决方案
问题出在read_tensor_from_image_file
函数内部。在循环的每次迭代中都会调用此函数。在函数中,您正在创建 Tensorflow 操作。根据经验,tf.anything
调用负责构建计算图。它们应该只被调用一次,然后使用tf.Session
. 实际上,您会不断地使用相同图像处理操作的“克隆”来增加计算图的大小,这会随着您的图变大而逐渐减慢执行速度。
您应该重构您的代码,以便其中的操作定义read_tensor_from_image_file
只执行一次,并且只执行sess.run(normalized)
循环内的部分。您可以使用 atf.placeholder
作为输入(文件名)。此外,您不应该在每次调用该函数时都创建一个新会话——而是从label_image
.
这是一个简化的示例,说明如何像这样重构代码。假设我们有一个创建图像处理操作的函数:
def process_image(file_name):
file_reader = tf.read_file(file_name, input_name)
...
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
return normalized
read_tensor_from_image_file
除了最后一部分涉及会话之外,这基本上是您的功能。你现在做的基本上是
def label_image(file_name, ...):
processed = process_image(file_name)
sess = tf.Session()
result = sess.run(processed)
....
for file_name in files:
label_image(file_name, ...)
相反,你应该做的是
filename_placeholder = tf.placeholder(tf.string)
processed = process_image(filename_placeholder)
def label_image(file_name, ...):
result = sess.run(processed, feed_dict={filename_placeholder: file_name})
....
for file_name in files:
label_image(file_name, ...)
重要的区别是我们将process_image
调用移出循环,仅run
将其移入循环。此外,我们不会连续创建新会话。全局变量有点恶心,但你应该明白了。
我唯一不确定的是你是否可以使用你得到的会话get_session(graph)
来运行processed
张量。如果这不起作用(即崩溃),那么您将需要创建第二个会话来运行这些东西,但是您应该只在调用后执行一次process_image
,而不是在循环内重复执行此操作。
推荐阅读
- reactjs - 如何在反应中使用 Link 和路由器将道具从一个组件传递到另一个组件
- python - 我可以更改 matplotlib 为值呈现 r 标签的方式吗?
- python - `tf.distribute.MirroredStrategy` 对训练结果有影响吗?
- flutter - 支持的最低 Gradle 版本为 5.6.4。当前版本是 5.6.2。过去 4 小时我一直在尝试解决 Flutter 中的此错误,但无法解决
- r - 从 GitHub 读取数据并将行拆分为 R 中的几行
- python - pywintypes在密钥存在时给出类未注册错误
- angular - CKEditor 5在Angular中调整图像大小
- linux - Yum update producing error on a package - so no packages update
- javascript - 在 Android 中运行无头 JavaScript 的有效方法
- python - 如何在python中将两个客户端连接到一台服务器?