首页 > 解决方案 > Tensorflow:使用线程池进行多 CPU 推理

问题描述

我有很多要并行处理的图像。

默认情况下,TensorFlow 可以使用多个内核,相关信息见:https://stackoverflow.com/a/41233901/1179925

“目前,这意味着每个线程池将在您的机器中每个 CPU 核心拥有一个线程。”

通过查看 htop,我可以看到在此默认设置下并非所有内核的使用率都达到 100%,因此我想设置 intra_op_parallelism_threads=1 和 inter_op_parallelism_threads=1,并行运行 n_cpu 个模型,但这样性能反而更差。

在我的 8 核笔记本上:

单核顺序处理:

Model init time: 0.77 sec
Processing time: 37.58 sec

多 CPU 默认 TensorFlow 设置:

Model init time: 0.76 sec
Processing time: 20.16 sec

此代码使用多处理:

Model init time: 0.78 sec
Processing time: 39.14 sec

这是我使用 multiprocessing 的代码,我遗漏了什么?

import os
import glob
import time
import argparse
from multiprocessing.pool import ThreadPool
import multiprocessing
import itertools

import tensorflow as tf
import numpy as np
from tqdm import tqdm
import cv2

# Path of the frozen Inception v3 graph (.pb) that is passed to ModelWrapper.
MODEL_FILEPATH = './tensorflow_example/inception_v3_2016_08_28_frozen.pb'

def get_image_filepaths(dataset_dir):
    """Recursively collect the .jpg/.png image paths found under a directory.

    Args:
        dataset_dir: Directory to search; sub-directories are included.

    Returns:
        List of matching file paths, each file listed only once.

    Raises:
        Exception: If ``dataset_dir`` is not an existing directory.
    """
    if not os.path.isdir(dataset_dir):
        raise Exception(dataset_dir, 'not dir!')

    img_filepaths = []
    seen = set()
    extensions = ['**/*.jpg', '**/*.png', '**/*.JPG', '**/*.PNG']
    for ext in extensions:
        for path in glob.iglob(os.path.join(dataset_dir, ext), recursive=True):
            # Where glob matching is case-insensitive (e.g. Windows), '*.jpg'
            # and '*.JPG' return the same file; normcase lets us keep only
            # the first occurrence so no image is processed twice.
            key = os.path.normcase(path)
            if key not in seen:
                seen.add(key)
                img_filepaths.append(path)

    return img_filepaths


class ModelWrapper():
    """Frozen TensorFlow graph plus a session for single-image inference.

    Node names and the input size are hardcoded for
    inception_v3_2016_08_28_frozen.pb. The session is configured with a GPU
    device count of 0 and one intra-op and one inter-op thread.
    """

    def __init__(self, model_filepath):
        """Load the frozen graph and open a session on it.

        Args:
            model_filepath: Path to the frozen graph in .pb format.
        """
        # TODO: estimate this from graph itself
        # Hardcoded for inception_v3_2016_08_28_frozen.pb
        self.input_node_names = ['input']
        self.output_node_names = ['InceptionV3/Predictions/Reshape_1']
        self.input_img_w = 299
        self.input_img_h = 299

        self.graph = self.load_graph(model_filepath)

        # Tensors are addressed as "<node name>:0" (first output of the node).
        self.inputs = [self.graph.get_tensor_by_name(name + ":0")
                       for name in self.input_node_names]
        self.outputs = [self.graph.get_tensor_by_name(name + ":0")
                        for name in self.output_node_names]

        config_proto = tf.ConfigProto(device_count={'GPU': 0},
                                      intra_op_parallelism_threads=1,
                                      inter_op_parallelism_threads=1)
        self.sess = tf.Session(graph=self.graph, config=config_proto)

    def load_graph(self, model_filepath):
        """Read a frozen graph (.pb) from disk and import it into a new tf.Graph.

        Args:
            model_filepath: Path to the frozen graph in .pb format.

        Returns:
            The graph with the imported definition (no name prefix).
        """
        # Expects frozen graph in .pb format
        with tf.gfile.GFile(model_filepath, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name="")
        return graph

    def predict(self, img):
        """Run the model on one image.

        Args:
            img: Image array of shape (height, width, channels). It is resized
                to the model input size when its height or width differs.

        Returns:
            List with one entry per output tensor, as returned by the
            session's ``run`` call.

        Raises:
            ValueError: If ``img`` is None, e.g. because ``cv2.imread`` could
                not read the file.
        """
        if img is None:
            # cv2.imread signals failure by returning None rather than raising;
            # fail here with a clear message instead of an AttributeError below.
            raise ValueError('img is None: the image could not be read')

        h, w, _ = img.shape
        if h != self.input_img_h or w != self.input_img_w:
            img = cv2.resize(img, (self.input_img_w, self.input_img_h))

        # Add a leading batch dimension of size 1.
        batch = img[np.newaxis, ...]
        feed_dict = {self.inputs[0]: batch}
        outputs = self.sess.run(self.outputs, feed_dict=feed_dict)  # (1, 1001)

        return outputs


def process_single_file(args):
    """Read one image from disk and run it through the model.

    Args:
        args: ``(model, img_filepath)`` pair packed into a single argument,
            as produced by the ``zip`` passed to the pool in process_dataset.

    The prediction is discarded; nothing is returned.
    """
    wrapper, path = args
    wrapper.predict(cv2.imread(path))


def process_dataset(dataset_dir):
    """Run the model over every image in a directory using a thread pool.

    A single ModelWrapper is created and shared by all worker threads. The
    model initialisation time and the total processing time are printed.

    Args:
        dataset_dir: Directory that is searched recursively for images.
    """
    img_filepaths = get_image_filepaths(dataset_dir)

    start = time.time()
    model = ModelWrapper(MODEL_FILEPATH)
    print('Model init time:', round(time.time() - start, 2), 'sec')

    start = time.time()
    n_cpu = multiprocessing.cpu_count()
    # Every task is a (model, path) pair because the pool passes one argument.
    tasks = zip(itertools.repeat(model), img_filepaths)
    # Use the pool as a context manager so its worker threads are released
    # once all images have been processed.
    with ThreadPool(n_cpu) as pool:
        results = pool.imap_unordered(process_single_file, tasks)
        # Drain the iterator inside the `with` block so every task finishes
        # before the pool is torn down; the results themselves are unused.
        for _ in tqdm(results, total=len(img_filepaths)):
            pass
    print('Processing time:', round(time.time() - start, 2), 'sec')


if __name__ == "__main__":
    # Command line: a single positional argument, the dataset directory.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(dest='dataset_dir')
    process_dataset(arg_parser.parse_args().dataset_dir)

更新:

将 multiprocessing.pool.ThreadPool 替换为 multiprocessing.Pool:

def process_dataset(dataset_dir):
    """Variant of process_dataset that uses a process pool instead of a thread pool.

    NOTE: as written this fails with
    ``TypeError: can't pickle _thread.RLock objects`` (see the traceback that
    follows). The error is raised while the pool pickles the task arguments,
    which include ``model``.

    Args:
        dataset_dir: Directory that is searched recursively for images.
    """
    img_filepaths = get_image_filepaths(dataset_dir)

    start = time.time()
    model = ModelWrapper(MODEL_FILEPATH)
    print('Model init time:', round(time.time() - start, 2), 'sec')

    start = time.time()
    n_cpu = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(n_cpu)

    # Each task carries the model object itself, so it must be pickled to reach
    # a worker process. NOTE(review): presumably the tf.Session held by the
    # wrapper is what cannot be pickled - confirm.
    it = pool.imap_unordered(process_single_file, zip(itertools.repeat(model), img_filepaths))
    # The results are unused; iterating only drives the progress bar.
    for _ in tqdm(it, total=len(img_filepaths)):
        pass

    print('Processing time:', round(time.time() - start, 2), 'sec')

我收到一个错误:

Traceback (most recent call last):
  File "tensorflow_example/multi_core_cpu_inference_multiprocessing.py", line 110, in <module>
    process_dataset(args.dataset_dir)
  File "tensorflow_example/multi_core_cpu_inference_multiprocessing.py", line 99, in process_dataset
    for _ in tqdm(it, total=len(img_filepaths)):
  File "/usr/local/lib/python3.6/site-packages/tqdm/_tqdm.py", line 979, in __iter__
    for obj in iterable:
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/pool.py", line 735, in next
    raise value
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/pool.py", line 424, in _handle_tasks
    put(task)
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.RLock objects

标签: python tensorflow multiprocessing threadpool batch-processing

解决方案


基于此答案:https://stackoverflow.com/a/46779776/1179925

它可以工作,但并不比 tensorflow 本身提供的默认并行性快多少。

import os
import glob
import time
import argparse
import multiprocessing

import tensorflow as tf
import numpy as np
from tqdm import tqdm
import cv2

# Worker layout: N_PROCESSES pool processes share the machine's CPU cores.

N_PROCESSES = 2
N_CPU = multiprocessing.cpu_count()
# Each process gets an equal share of the cores (at least one thread) for
# both the intra-op and the inter-op TensorFlow thread setting.
INTRA_N_THREADS = INTER_N_THREADS = max(1, N_CPU // N_PROCESSES)

# Echo the configuration; this runs whenever the module is imported.
for _label, _value in (('N_PROCESSES', N_PROCESSES),
                       ('N_CPU', N_CPU),
                       ('INTRA_N_THREADS', INTRA_N_THREADS),
                       ('INTER_N_THREADS', INTER_N_THREADS)):
    print(_label, _value)

# Path of the frozen Inception v3 graph (.pb) that is passed to ModelWrapper.
MODEL_FILEPATH = './tensorflow_example/inception_v3_2016_08_28_frozen.pb'

def get_image_filepaths(dataset_dir):
    """Return the paths of all .jpg/.png images below ``dataset_dir``.

    The search is recursive and covers lower- and upper-case extensions.

    Args:
        dataset_dir: Directory to search.

    Returns:
        List of matching file paths without duplicates.

    Raises:
        Exception: If ``dataset_dir`` is not an existing directory.
    """
    if not os.path.isdir(dataset_dir):
        raise Exception(dataset_dir, 'not dir!')

    img_filepaths = []
    seen = set()
    extensions = ['**/*.jpg', '**/*.png', '**/*.JPG', '**/*.PNG']
    for ext in extensions:
        pattern = os.path.join(dataset_dir, ext)
        for path in glob.iglob(pattern, recursive=True):
            # Case-insensitive matching (e.g. on Windows) makes the lower- and
            # upper-case patterns hit the same file; keep the first hit only.
            key = os.path.normcase(path)
            if key in seen:
                continue
            seen.add(key)
            img_filepaths.append(path)

    return img_filepaths


class ModelWrapper():
    """Frozen TensorFlow graph plus a session for single-image inference.

    Node names and the input size are hardcoded for
    inception_v3_2016_08_28_frozen.pb. The session is configured with a GPU
    device count of 0 and takes its thread counts from the module-level
    INTRA_N_THREADS / INTER_N_THREADS constants.
    """

    def __init__(self, model_filepath):
        """Load the frozen graph and open a session on it.

        Args:
            model_filepath: Path to the frozen graph in .pb format.
        """
        # TODO: estimate this from graph itself
        # Hardcoded for inception_v3_2016_08_28_frozen.pb
        self.input_node_names = ['input']
        self.output_node_names = ['InceptionV3/Predictions/Reshape_1']
        self.input_img_w = 299
        self.input_img_h = 299

        self.graph = self.load_graph(model_filepath)

        # Tensors are addressed as "<node name>:0" (first output of the node).
        self.inputs = [self.graph.get_tensor_by_name(name + ":0")
                       for name in self.input_node_names]
        self.outputs = [self.graph.get_tensor_by_name(name + ":0")
                        for name in self.output_node_names]

        config_proto = tf.ConfigProto(device_count={'GPU': 0},
                                      intra_op_parallelism_threads=INTRA_N_THREADS,
                                      inter_op_parallelism_threads=INTER_N_THREADS)
        self.sess = tf.Session(graph=self.graph, config=config_proto)

    def load_graph(self, model_filepath):
        """Read a frozen graph (.pb) from disk and import it into a new tf.Graph.

        Args:
            model_filepath: Path to the frozen graph in .pb format.

        Returns:
            The graph with the imported definition (no name prefix).
        """
        # Expects frozen graph in .pb format
        with tf.gfile.GFile(model_filepath, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name="")
        return graph

    def predict(self, img):
        """Run the model on one image.

        Args:
            img: Image array of shape (height, width, channels). It is resized
                to the model input size when its height or width differs.

        Returns:
            List with one entry per output tensor, as returned by the
            session's ``run`` call.

        Raises:
            ValueError: If ``img`` is None, e.g. because ``cv2.imread`` could
                not read the file.
        """
        if img is None:
            # cv2.imread signals failure by returning None rather than raising;
            # fail here with a clear message instead of an AttributeError below.
            raise ValueError('img is None: the image could not be read')

        h, w, _ = img.shape
        if h != self.input_img_h or w != self.input_img_w:
            img = cv2.resize(img, (self.input_img_w, self.input_img_h))

        # Add a leading batch dimension of size 1.
        batch = img[np.newaxis, ...]
        feed_dict = {self.inputs[0]: batch}
        outputs = self.sess.run(self.outputs, feed_dict=feed_dict)  # (1, 1001)

        return outputs


def process_chunk(img_filepaths):
    """Create a model and run it over a sequence of image files.

    Prints how long the model initialisation took. The predictions are
    discarded and nothing is returned.

    Args:
        img_filepaths: Paths of the images to process.
    """
    t0 = time.time()
    wrapper = ModelWrapper(MODEL_FILEPATH)
    print('Model init time:', round(time.time() - t0, 2), 'sec')

    for path in img_filepaths:
        wrapper.predict(cv2.imread(path))


def process_dataset(dataset_dir):
    """Run the model over every image in a directory using a process pool.

    The image list is split into at most N_PROCESSES chunks and each chunk is
    handed to process_chunk in a worker process. The total processing time is
    printed.

    Args:
        dataset_dir: Directory that is searched recursively for images.
    """
    img_filepaths = get_image_filepaths(dataset_dir)

    start = time.time()

    # Ceiling division yields at most N_PROCESSES chunks (floor division leaves
    # an extra remainder chunk). max(1, ...) keeps the range() step non-zero
    # when the list is empty; a zero step raises ValueError.
    n = max(1, (len(img_filepaths) + N_PROCESSES - 1) // N_PROCESSES)
    chunks = [img_filepaths[i:i + n] for i in range(0, len(img_filepaths), n)]

    # The context manager releases the worker processes when the block exits.
    with multiprocessing.Pool(N_PROCESSES) as pool:
        it = pool.imap_unordered(process_chunk, chunks)
        # One result arrives per chunk, so the bar counts chunks, not images.
        # Draining inside the `with` block lets every chunk finish first.
        for _ in tqdm(it, total=len(chunks)):
            pass

    print('Processing time:', round(time.time() - start, 2), 'sec')


if __name__ == "__main__":
    # Entry point: the only command-line argument is the dataset directory.
    cli = argparse.ArgumentParser()
    cli.add_argument(dest='dataset_dir')
    parsed = cli.parse_args()
    process_dataset(parsed.dataset_dir)

推荐阅读