python - 如何加快 Keras 和 Tensorflow 中模型的加载速度?
问题描述
假设我有大量训练有素的 keras 模型,有时我必须加载它们并进行预测。我需要以最快的方式加载每个模型并预测数据。我认为最快的解决方案是将它们存储在内存中,但是我想这不是好方法,因为很快 RAM 就会溢出?所以在这一点上,我用这样的东西达到了最高的性能。
K.clear_session()
random_model = load_model('path/to/' + str(random_model))
results = random_model(final_input_row)
此外,我有五个几乎一直在使用的模型,在这种情况下,性能更加重要。我在启动服务器时加载它们,并且我可以不断地访问它们。
graph = tf.get_default_graph()
with graph.as_default():
constant_model = load_model(
'path/to/constant_model')
预言:
with graph.as_default():
results = constant_model(final_input_row)
问题是K.clear_session()
我在加载过程中执行期间random_models
我从内存中丢失了它们。没有K.clear_session()
加载random_models
持续时间过长。你有什么想法我该如何解决这个问题?我什至可以使用完全不同的方法。
更新
我尝试做这样的事情:
class LoadModel:
def __init__(self, path):
self.path = path
self.sess = tf.Session(config=config)
self.graph = tf.get_default_graph()
K.set_session(self.sess)
with self.graph.as_default():
self.model = load_model(self.path)
def do_predictions(self, x):
with self.graph.as_default():
return self.model.predict(x)
然后当我执行时:
random_model = LoadModel('./path/to/random_model.h5')
results = random_model.do_predictions(final_input_row)
预测数据大约需要 3 秒。random_models
在我有很多模型的情况下,这是可以接受的。但是,constant_models
当我有五个并且我需要不断访问它们时,它会持续太长时间。到目前为止,我这样做的方式是在启动 Django 服务器的过程中加载模型,并将其存储在内存中,然后我就可以运行results = constant_model.do_predictions(final_input_row)
了,而且速度非常快。在我运行之前它可以正常工作random_models
,之后在我收到的每个请求中
tensorflow.python.framework.errors_impl.InvalidArgumentError: Tensor lstm_14_input:0, specified in either feed_devices or fetch_devices was not found in the Graph
[24/Jun/2019 10:11:00] "POST /request/ HTTP/1.1" 500 17326
[24/Jun/2019 10:11:02] "GET /model/ HTTP/1.1" 201 471
Exception ignored in: <bound method BaseSession._Callable.__del__ of <tensorflow.python.client.session.BaseSession._Callable object at 0x130069908>>
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1455, in __del__
self._session._session, self._handle, status)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: No such callable handle: 140576717540032
[24/Jun/2019 10:11:02] "GET /model/ HTTP/1.1" 201 471
[24/Jun/2019 10:11:07] "GET /model/ HTTP/1.1" 201 471
显然,如果我constant_model = LoadModel('./path/to/constant_model.h5')
在每次运行之前运行它可以正常工作,results = constant_model.do_predictions(final_input_row)
但是正如我所提到的那样它太慢了。任何想法如何解决这个问题?
更新2
我尝试以这种方式
session = tf.Session(config=config)
K.set_session(session)
class LoadModel:
def __init__(self, path):
self.path = path
self.graph = tf.get_default_graph()
with self.graph.as_default():
self.model = load_model(self.path)
def do_predictions(self, x):
with self.graph.as_default():
return self.model.predict(x)
但我仍然得到TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(1, 128), dtype=float32) is not an element of this graph.
解决方案 以下是我的工作解决方案。如果有人知道加载满足上述要求的模型的更有效方法,我将不胜感激。
class LoadModel:
def __init__(self, model_path):
with K.get_session().graph.as_default():
self.model = load_model(model_path)
self.model._make_predict_function()
self.graph = tf.get_default_graph()
def predict_data(self, data):
with self.graph.as_default():
output = self.model.predict(data)
return output
解决方案
正如您所说,您一直在使用所有模型,最好将它们保存在内存中,这样您就不会在每次发出请求时都加载和预测。例如,您应该定义一个加载模型的类,并为不同的模型保留此类的不同实例。这不是经过测试的代码,因此您可能需要进行一些更改。
# if you are using GPU
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
class LoadModel:
def __init__(self, path_to_model):
self.path = path
self.sess = tf.Session(config=config)
self.graph = tf.get_default_graph()
K.set_session(self.sess)
self.model = self.load()
def load(self):
with graph.as_default():
model = load_model(self.path)
return model
def do_predictions(self, x):
return self.model.predict(x)
推荐阅读
- python-3.x - 我给python中的变量重命名csv位置路径,它给出了错误unicodedecodeError
- redux - yield 调用返回未定义而不是已解决的承诺值
- python - 如何使用 win32com.client 通过 python 将 XLA 添加到 excel 中?
- opencart - 如何为变量创建条件?
- php - 代码需要更新以匹配 PHP7
- javascript - 使用Javascript动态循环遍历对象数组
- java - 为什么 URI 构造函数允许缺少协议(而 URL 不允许)?
- java - 这是否可以只调用一次“Okay Glass”
- reactjs - react js onclick函数不调用
- php - 将模板添加到 WHMCS