python - 为什么在 Keras 中使用多处理时推理时间会变慢?
问题描述
我想要几个进程,每个进程一次加载一个不同的图像并执行推理(例如 VGG16)。
我正在使用带有 tensorFlow 后端的 Keras,一个 GPU(GTX 1070)。以下是代码:
import tensorflow as tf
import multiprocessing
from multiprocessing import Pool, Process, Queue
import os
from os.path import isfile, join
from PIL import Image
import time
from keras.applications.vgg16 import VGG16
import numpy as np
from keras.backend.tensorflow_backend import set_session
test_path = 'test path to images ...'
output = Queue()
def worker(file_names, output):
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.25
config.gpu_options.visible_device_list = "0"
set_session(tf.Session(config=config))
inference_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3), pooling='avg')
model_image_size = (224,224)
times = []
for file_name in file_names:
image = Image.open(os.path.join(test_path, file_name))
im_width = image.size[0]
im_height = image.size[1]
m = (im_width - im_height) // 2
image = image.crop((m, 0, im_width - m, im_height))
image = image.resize((model_image_size), Image.BICUBIC)
image = np.array(image, dtype='float32')
image /= 255.
image = np.expand_dims(image, 0) # Add batch dimension.
start = time.time()
res = inference_model.predict(image)
end = time.time()
elapsed_time = end - start
print("elapsed time", elapsed_time)
times.append(elapsed_time)
average_time = np.mean(times[2:])
print("average time ", average_time)
if __name__ == '__main__':
file_names = [f for f in os.listdir(test_path) if isfile(join(test_path, f))]
file_names.sort()
num_workers = 3
processes = [Process(target=worker, args=(file_names[x::num_workers], output)) for x in range(num_workers)]
for p in processes:
p.start()
for p in processes:
p.join()
我注意到与单进程相比,多进程的每个图像的推理经过时间较慢。例如,对于单个图像,推理经过的时间是 0.012 秒。当运行 3 个进程时,我希望得到相同的结果,但是,每张图像的平均推理时间几乎是 0.02 秒。这可能是什么原因?(也许 CUDA 上下文 - 切换?)有没有办法解决这个问题?
解决方案
推荐阅读
- jquery - 循环通过文件客户端进行文件上传控制
- git - 如何在 Git 触发器运行之前检查第三方网站的更新?
- java - 我正在使用条带支付 java 实现。没有此类费用的条纹退款存在成功付款的例外情况
- r - sp_execute_external_script R 脚本'无法启动 png() 设备'
- node.js - Node.js - 如何从数组中检索值?
- html - Angular ng-select 多选复选框
- java - java中匹配字符串所需的正则表达式模式以“{{”开头并以“}}”结尾
- css - 如何将动态 TailwindCSS 类添加到 React 中的 DOM 元素
- reactjs - 一次向地图状态对象添加值
- haskell - Haskell 用新书替换给定的现有书