python - Tensorflow:使用线程池进行多 CPU 推理
问题描述
我有很多要并行处理的图像。
默认情况下,TensorFlow 可以使用多个内核,相关说明见:https://stackoverflow.com/a/41233901/1179925
“目前,这意味着每个线程池将在您的机器中每个 CPU 核心拥有一个线程。”
通过查看 htop,我可以看到在此默认设置下并非所有内核都被 100% 使用,因此我想设置 intra_op_parallelism_threads=1 和 inter_op_parallelism_threads=1,并行运行 n_cpu 个模型,但这样性能反而更差。
在我的 8 核笔记本上:
单核顺序处理:
Model init time: 0.77 sec
Processing time: 37.58 sec
多 CPU 默认 TensorFlow 设置:
Model init time: 0.76 sec
Processing time: 20.16 sec
下面这段使用 multiprocessing 的代码:
Model init time: 0.78 sec
Processing time: 39.14 sec
这是我使用 multiprocessing 的代码,我遗漏了什么?
import os
import glob
import time
import argparse
from multiprocessing.pool import ThreadPool
import multiprocessing
import itertools
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import cv2
# Path of the frozen-graph (.pb) model file that ModelWrapper is constructed with.
MODEL_FILEPATH = './tensorflow_example/inception_v3_2016_08_28_frozen.pb'
def get_image_filepaths(dataset_dir):
    """Recursively collect the paths of all JPEG/PNG files under a directory.

    The directory tree is walked once and the file extension is compared
    case-insensitively, so ``.jpg``, ``.JPG`` and ``.Jpg`` are all accepted
    and no path can be returned twice (which separate ``*.jpg`` / ``*.JPG``
    patterns do on platforms where glob matching ignores case).

    Args:
        dataset_dir: Path of the directory to scan.

    Returns:
        Sorted list of paths (each prefixed with ``dataset_dir``) of regular
        files whose name ends in ``.jpg`` or ``.png``.

    Raises:
        Exception: If ``dataset_dir`` is not an existing directory.
    """
    if not os.path.isdir(dataset_dir):
        raise Exception(dataset_dir, 'not dir!')

    image_suffixes = ('.jpg', '.png')
    # '**' with recursive=True also matches zero directories, so files that
    # sit directly inside dataset_dir are included.
    pattern = os.path.join(dataset_dir, '**', '*')
    return sorted(
        path
        for path in glob.iglob(pattern, recursive=True)
        # isfile() drops directories that merely look like image names.
        if path.lower().endswith(image_suffixes) and os.path.isfile(path)
    )
class ModelWrapper():
    """Frozen TensorFlow graph plus a session, wrapped for single-image inference.

    The wrapper owns a ``tf.Session``; call :meth:`close` (or use the object
    as a context manager) to release it when inference is finished.
    """

    def __init__(self, model_filepath, intra_op_threads=1, inter_op_threads=1):
        """Load the graph and open a session on it.

        Args:
            model_filepath: Path of the frozen graph in ``.pb`` format.
            intra_op_threads: Value passed as ``intra_op_parallelism_threads``
                in the session config (defaults to 1).
            inter_op_threads: Value passed as ``inter_op_parallelism_threads``
                in the session config (defaults to 1).
        """
        # TODO: estimate this from graph itself
        # Hardcoded for inception_v3_2016_08_28_frozen.pb
        self.input_node_names = ['input']
        self.output_node_names = ['InceptionV3/Predictions/Reshape_1']
        self.input_img_w = 299
        self.input_img_h = 299

        self.graph = self.load_graph(model_filepath)

        # Tensor names are the node names with the ":0" output suffix appended.
        self.inputs = [self.graph.get_tensor_by_name(name + ":0")
                       for name in self.input_node_names]
        self.outputs = [self.graph.get_tensor_by_name(name + ":0")
                        for name in self.output_node_names]

        # device_count={'GPU': 0} requests zero GPU devices for this session.
        config_proto = tf.ConfigProto(device_count={'GPU': 0},
                                      intra_op_parallelism_threads=intra_op_threads,
                                      inter_op_parallelism_threads=inter_op_threads)
        self.sess = tf.Session(graph=self.graph, config=config_proto)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the underlying session."""
        self.sess.close()

    def load_graph(self, model_filepath):
        """Read a serialized GraphDef from disk and import it into a new graph.

        Args:
            model_filepath: Path of the frozen graph in ``.pb`` format.

        Returns:
            The ``tf.Graph`` the definition was imported into (no name prefix).
        """
        # Expects frozen graph in .pb format
        with tf.gfile.GFile(model_filepath, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name="")
        return graph

    def predict(self, img):
        """Run the graph on a single image.

        Args:
            img: Image array of shape ``(height, width, channels)``. It is
                resized to the model input size when its size differs.

        Returns:
            The list returned by ``Session.run`` for the output tensors.

        Raises:
            ValueError: If ``img`` is ``None`` (for example an image that
                could not be read from disk).
        """
        if img is None:
            raise ValueError('img is None: the image could not be loaded')

        h, w, _ = img.shape
        if h != self.input_img_h or w != self.input_img_w:
            # cv2.resize takes the target size as (width, height).
            img = cv2.resize(img, (self.input_img_w, self.input_img_h))

        # NOTE(review): the pixels are fed exactly as given (no channel
        # reordering or value scaling happens here) - confirm this matches
        # the preprocessing the graph expects.
        batch = img[np.newaxis, ...]
        feed_dict = {self.inputs[0]: batch}
        outputs = self.sess.run(self.outputs, feed_dict=feed_dict)  # (1, 1001)
        return outputs
def process_single_file(args):
    """Read one image from disk and run the model on it.

    The inputs arrive as a single tuple so the function can be used directly
    as the callable of a pool ``imap``/``map`` call.

    Args:
        args: ``(model, img_filepath)`` pair; ``model`` must provide
            ``predict(img)``.

    Returns:
        Whatever ``model.predict`` returns for the image.

    Raises:
        IOError: If the image file could not be read.
    """
    model, img_filepath = args
    img = cv2.imread(img_filepath)
    if img is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        raise IOError('Failed to read image: {}'.format(img_filepath))
    return model.predict(img)
def process_dataset(dataset_dir):
    """Run the model over every image in a directory using a thread pool.

    One model instance is created and shared by all worker threads; the
    model-initialisation time and the total processing time are printed.

    Args:
        dataset_dir: Directory that is scanned for images.
    """
    img_filepaths = get_image_filepaths(dataset_dir)

    start = time.time()
    model = ModelWrapper(MODEL_FILEPATH)
    print('Model init time:', round(time.time() - start, 2), 'sec')

    start = time.time()
    n_cpu = multiprocessing.cpu_count()
    # Each task is a (model, filepath) tuple; the same model object is repeated.
    tasks = zip(itertools.repeat(model), img_filepaths)
    # The context manager shuts the worker threads down once all results
    # have been consumed.
    with ThreadPool(n_cpu) as pool:
        results = pool.imap_unordered(process_single_file, tasks)
        for _ in tqdm(results, total=len(img_filepaths)):
            pass
    print('Processing time:', round(time.time() - start, 2), 'sec')
if __name__ == "__main__":
    # Command-line entry point: one positional argument naming the image directory.
    parser = argparse.ArgumentParser()
    parser.add_argument(dest='dataset_dir',
                        help='directory that is searched recursively for images')
    args = parser.parse_args()
    process_dataset(args.dataset_dir)
更新:
将 multiprocessing.pool.ThreadPool 替换为 multiprocessing.Pool:
def process_dataset(dataset_dir):
    """Variant of the benchmark that uses ``multiprocessing.Pool`` instead of threads.

    NOTE: this is the version that fails with the ``TypeError: can't pickle
    _thread.RLock objects`` traceback quoted after this snippet. Every task
    tuple built below contains ``model``, and the traceback shows the pool
    failing while pickling a task to send it to a worker process. Presumably
    the unpicklable lock lives in the TensorFlow session/graph held by the
    wrapper; verify.

    Args:
        dataset_dir: Directory that is scanned for images.
    """
    img_filepaths = get_image_filepaths(dataset_dir)
    start = time.time()
    model = ModelWrapper(MODEL_FILEPATH)
    print('Model init time:', round(time.time() - start, 2), 'sec')
    start = time.time()
    n_cpu = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(n_cpu)
    # The model object travels inside every task tuple, which is what the
    # pool then tries (and fails) to pickle.
    it = pool.imap_unordered(process_single_file, zip(itertools.repeat(model), img_filepaths))
    for _ in tqdm(it, total=len(img_filepaths)):
        pass
    print('Processing time:', round(time.time() - start, 2), 'sec')
我收到一个错误:
Traceback (most recent call last):
File "tensorflow_example/multi_core_cpu_inference_multiprocessing.py", line 110, in <module>
process_dataset(args.dataset_dir)
File "tensorflow_example/multi_core_cpu_inference_multiprocessing.py", line 99, in process_dataset
for _ in tqdm(it, total=len(img_filepaths)):
File "/usr/local/lib/python3.6/site-packages/tqdm/_tqdm.py", line 979, in __iter__
for obj in iterable:
File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/pool.py", line 735, in next
raise value
File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/pool.py", line 424, in _handle_tasks
put(task)
File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/connection.py", line 206, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.RLock objects
解决方案
基于此答案:https://stackoverflow.com/a/46779776/1179925
它可以工作,但并不比 tensorflow 本身提供的默认并行性快多少。
import os
import glob
import time
import argparse
import multiprocessing
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import cv2
# Running N_PROCESSES processes using multiprocessing pool
N_PROCESSES = 2
# CPU count as reported by multiprocessing.cpu_count().
N_CPU = multiprocessing.cpu_count()
# TensorFlow thread-pool sizes used by each worker's session: the CPU count is
# divided evenly between the worker processes, never dropping below 1.
INTRA_N_THREADS = max(1, N_CPU // N_PROCESSES)
INTER_N_THREADS = max(1, N_CPU // N_PROCESSES)
# Echo the effective settings.
# NOTE(review): these prints execute at import time, so they may be repeated
# by worker processes if the start method re-imports this module - confirm.
print('N_PROCESSES', N_PROCESSES)
print('N_CPU', N_CPU)
print('INTRA_N_THREADS', INTRA_N_THREADS)
print('INTER_N_THREADS', INTER_N_THREADS)
# Path of the frozen-graph (.pb) model file that ModelWrapper is constructed with.
MODEL_FILEPATH = './tensorflow_example/inception_v3_2016_08_28_frozen.pb'
def get_image_filepaths(dataset_dir):
    """Recursively collect the paths of all JPEG/PNG files under a directory.

    The directory tree is walked once and the file extension is compared
    case-insensitively, so ``.jpg``, ``.JPG`` and ``.Jpg`` are all accepted
    and no path can be returned twice (which separate ``*.jpg`` / ``*.JPG``
    patterns do on platforms where glob matching ignores case).

    Args:
        dataset_dir: Path of the directory to scan.

    Returns:
        Sorted list of paths (each prefixed with ``dataset_dir``) of regular
        files whose name ends in ``.jpg`` or ``.png``.

    Raises:
        Exception: If ``dataset_dir`` is not an existing directory.
    """
    if not os.path.isdir(dataset_dir):
        raise Exception(dataset_dir, 'not dir!')

    image_suffixes = ('.jpg', '.png')
    # '**' with recursive=True also matches zero directories, so files that
    # sit directly inside dataset_dir are included.
    pattern = os.path.join(dataset_dir, '**', '*')
    return sorted(
        path
        for path in glob.iglob(pattern, recursive=True)
        # isfile() drops directories that merely look like image names.
        if path.lower().endswith(image_suffixes) and os.path.isfile(path)
    )
class ModelWrapper():
    """Frozen TensorFlow graph plus a session, wrapped for single-image inference.

    The session's thread-pool sizes come from the module-level
    ``INTRA_N_THREADS`` / ``INTER_N_THREADS`` constants. The wrapper owns a
    ``tf.Session``; call :meth:`close` (or use the object as a context
    manager) to release it when inference is finished.
    """

    def __init__(self, model_filepath):
        """Load the graph and open a session on it.

        Args:
            model_filepath: Path of the frozen graph in ``.pb`` format.
        """
        # TODO: estimate this from graph itself
        # Hardcoded for inception_v3_2016_08_28_frozen.pb
        self.input_node_names = ['input']
        self.output_node_names = ['InceptionV3/Predictions/Reshape_1']
        self.input_img_w = 299
        self.input_img_h = 299

        self.graph = self.load_graph(model_filepath)

        # Tensor names are the node names with the ":0" output suffix appended.
        self.inputs = [self.graph.get_tensor_by_name(name + ":0")
                       for name in self.input_node_names]
        self.outputs = [self.graph.get_tensor_by_name(name + ":0")
                        for name in self.output_node_names]

        # device_count={'GPU': 0} requests zero GPU devices for this session.
        config_proto = tf.ConfigProto(device_count={'GPU': 0},
                                      intra_op_parallelism_threads=INTRA_N_THREADS,
                                      inter_op_parallelism_threads=INTER_N_THREADS)
        self.sess = tf.Session(graph=self.graph, config=config_proto)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the underlying session."""
        self.sess.close()

    def load_graph(self, model_filepath):
        """Read a serialized GraphDef from disk and import it into a new graph.

        Args:
            model_filepath: Path of the frozen graph in ``.pb`` format.

        Returns:
            The ``tf.Graph`` the definition was imported into (no name prefix).
        """
        # Expects frozen graph in .pb format
        with tf.gfile.GFile(model_filepath, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name="")
        return graph

    def predict(self, img):
        """Run the graph on a single image.

        Args:
            img: Image array of shape ``(height, width, channels)``. It is
                resized to the model input size when its size differs.

        Returns:
            The list returned by ``Session.run`` for the output tensors.

        Raises:
            ValueError: If ``img`` is ``None`` (for example an image that
                could not be read from disk).
        """
        if img is None:
            raise ValueError('img is None: the image could not be loaded')

        h, w, _ = img.shape
        if h != self.input_img_h or w != self.input_img_w:
            # cv2.resize takes the target size as (width, height).
            img = cv2.resize(img, (self.input_img_w, self.input_img_h))

        # NOTE(review): the pixels are fed exactly as given (no channel
        # reordering or value scaling happens here) - confirm this matches
        # the preprocessing the graph expects.
        batch = img[np.newaxis, ...]
        feed_dict = {self.inputs[0]: batch}
        outputs = self.sess.run(self.outputs, feed_dict=feed_dict)  # (1, 1001)
        return outputs
def process_chunk(img_filepaths):
    """Run inference over a list of image files with a freshly created model.

    A new ModelWrapper is built for the chunk (its initialisation time is
    printed), every image is read and passed to ``predict``, and the model's
    session is closed afterwards so that a worker handling several chunks
    does not accumulate open sessions.

    Args:
        img_filepaths: Paths of the images in this chunk.

    Returns:
        Number of images processed.

    Raises:
        IOError: If an image file could not be read.
    """
    start = time.time()
    model = ModelWrapper(MODEL_FILEPATH)
    print('Model init time:', round(time.time() - start, 2), 'sec')
    try:
        for img_filepath in img_filepaths:
            img = cv2.imread(img_filepath)
            if img is None:
                # cv2.imread reports an unreadable file by returning None.
                raise IOError('Failed to read image: {}'.format(img_filepath))
            model.predict(img)
    finally:
        model.sess.close()
    return len(img_filepaths)
def process_dataset(dataset_dir):
    """Split the images of a directory into chunks and process them in a process pool.

    The file list is cut into at most ``N_PROCESSES`` contiguous chunks, each
    handled by ``process_chunk`` in a worker process. The total processing
    time is printed at the end.

    Args:
        dataset_dir: Directory that is scanned for images.
    """
    img_filepaths = get_image_filepaths(dataset_dir)
    start = time.time()

    # Ceiling division gives at most N_PROCESSES chunks (floor division would
    # leave an extra tiny chunk); max(1, ...) keeps the range() step non-zero
    # when there are fewer images than processes or none at all.
    chunk_size = max(1, -(-len(img_filepaths) // N_PROCESSES))
    chunks = [img_filepaths[i:i + chunk_size]
              for i in range(0, len(img_filepaths), chunk_size)]

    # The context manager tears the worker processes down after all results
    # have been consumed.
    with multiprocessing.Pool(N_PROCESSES) as pool:
        it = pool.imap_unordered(process_chunk, chunks)
        # One result arrives per chunk, so the bar counts chunks, not images.
        for _ in tqdm(it, total=len(chunks), unit='chunk'):
            pass
    print('Processing time:', round(time.time() - start, 2), 'sec')
if __name__ == "__main__":
    # Command-line entry point: one positional argument naming the image directory.
    parser = argparse.ArgumentParser()
    parser.add_argument(dest='dataset_dir',
                        help='directory that is searched recursively for images')
    args = parser.parse_args()
    process_dataset(args.dataset_dir)
推荐阅读
- sqlite - SQLite中的普通索引和FTS有什么区别
- python - 尽管我有一个具有特定值的 return 语句,但我的函数返回 None
- pkg-config - 如何在 Windows 10 上使用 vcpkg 安装 pkg-config、sigc++-2.0
- java - JPA @OrderBy 不对关系进行排序
- javascript - 从最大到最小的reactjs映射特定数组
- python-3.x - 使用多个输入进行测试时字符串切片的意外输出
- php - 注册时未在我的数据库中收到我的 php 表单数据
- python - 我试图在pycharm的python中制作一个简单的密码程序,当我给出if语句时它给出了一个错误
- ios - 我想通过点击按钮添加文本字段
- android - 如何使行项目适合屏幕