首页 > 解决方案 > 如何加快 Keras 和 Tensorflow 中模型的加载速度?

问题描述

假设我有大量训练有素的 keras 模型,有时我必须加载它们并进行预测。我需要以最快的方式加载每个模型并预测数据。我认为最快的解决方案是将它们存储在内存中,但是我想这不是好方法,因为很快 RAM 就会溢出?所以在这一点上,我用这样的东西达到了最高的性能。

K.clear_session()
random_model = load_model('path/to/' + str(random_model))
results = random_model(final_input_row)

此外,我有五个几乎一直在使用的模型,在这种情况下,性能更加重要。我在启动服务器时加载它们,并且我可以不断地访问它们。

graph = tf.get_default_graph()

with graph.as_default():

    constant_model = load_model(
            'path/to/constant_model')

预言:

with graph.as_default():
    results = constant_model(final_input_row)

问题是K.clear_session()我在加载过程中执行期间random_models我从内存中丢失了它们。没有K.clear_session()加载random_models持续时间过长。你有什么想法我该如何解决这个问题?我什至可以使用完全不同的方法。

更新

我尝试做这样的事情:

class LoadModel:
    def __init__(self, path):
        self.path = path
        self.sess  = tf.Session(config=config)
        self.graph = tf.get_default_graph()
        K.set_session(self.sess)
        with self.graph.as_default():
            self.model = load_model(self.path)

    def do_predictions(self, x):
        with self.graph.as_default():
            return self.model.predict(x)

然后当我执行时:

random_model = LoadModel('./path/to/random_model.h5')
results = random_model.do_predictions(final_input_row)

预测数据大约需要 3 秒。random_models在我有很多模型的情况下,这是可以接受的。但是,constant_models当我有五个并且我需要不断访问它们时,它会持续太长时间。到目前为止,我这样做的方式是在启动 Django 服务器的过程中加载模型,并将其存储在内存中,然后我就可以运行results = constant_model.do_predictions(final_input_row)了,而且速度非常快。在我运行之前它可以正常工作random_models,之后在我收到的每个请求中

tensorflow.python.framework.errors_impl.InvalidArgumentError: Tensor lstm_14_input:0, specified in either feed_devices or fetch_devices was not found in the Graph
[24/Jun/2019 10:11:00] "POST /request/ HTTP/1.1" 500 17326
[24/Jun/2019 10:11:02] "GET /model/ HTTP/1.1" 201 471
Exception ignored in: <bound method BaseSession._Callable.__del__ of <tensorflow.python.client.session.BaseSession._Callable object at 0x130069908>>
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1455, in __del__
    self._session._session, self._handle, status)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: No such callable handle: 140576717540032
[24/Jun/2019 10:11:02] "GET /model/ HTTP/1.1" 201 471
[24/Jun/2019 10:11:07] "GET /model/ HTTP/1.1" 201 471

显然,如果我constant_model = LoadModel('./path/to/constant_model.h5')在每次运行之前运行它可以正常工作,results = constant_model.do_predictions(final_input_row)但是正如我所提到的那样它太慢了。任何想法如何解决这个问题?

更新2

我尝试以这种方式

session  = tf.Session(config=config)
K.set_session(session)

class LoadModel:
    def __init__(self, path):
        self.path = path

        self.graph = tf.get_default_graph()
        with self.graph.as_default():
            self.model = load_model(self.path)

    def do_predictions(self, x):
        with self.graph.as_default():
            return self.model.predict(x)

但我仍然得到TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(1, 128), dtype=float32) is not an element of this graph.

解决方案 以下是我的工作解决方案。如果有人知道加载满足上述要求的模型的更有效方法,我将不胜感激。

class LoadModel:

    def __init__(self, model_path):
        with K.get_session().graph.as_default():
            self.model = load_model(model_path)
            self.model._make_predict_function()
            self.graph = tf.get_default_graph()

    def predict_data(self, data):
        with self.graph.as_default():
            output = self.model.predict(data)
            return output

标签: pythondjangotensorflowkerasdjango-rest-framework

解决方案


正如您所说,您一直在使用所有模型,最好将它们保存在内存中,这样您就不会在每次发出请求时都加载和预测。例如,您应该定义一个加载模型的类,并为不同的模型保留此类的不同实例。这不是经过测试的代码,因此您可能需要进行一些更改。

# if you are using GPU
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

class LoadModel:
    def __init__(self, path_to_model):
        self.path  = path
        self.sess  = tf.Session(config=config)
        self.graph = tf.get_default_graph()
        K.set_session(self.sess) 
        self.model = self.load()

    def load(self):
        with graph.as_default():
            model = load_model(self.path)
        return model

    def do_predictions(self, x):
        return self.model.predict(x)    

推荐阅读