python - 在多处理中使用 model.predict (Keras + TF)
问题描述
我有以下问题。我正在使用 Tensorflow Keras 模型来评估连续传感器数据。我的模型输入由 15 个传感器数据帧组成。因为函数 model.predict() 需要将近 1 秒,所以我想异步执行这个函数,这样我就可以在这个时间段内收集下一个数据帧。为此,我创建了一个带有多处理库和一个用于 model.predict 的函数的池。我的代码看起来像这样:
def predictData(data):
return model.predict(data)
global model
model = tf.keras.models.load_model("Network.h5")
model._make_predict_function()
p = Pool(processes = 4)
...
res = p.apply_async(predictData, ([[iinput]],))
print(res.get(timeout = 10))
现在我在调用 predictData() 时总是遇到超时错误。似乎 model.predict() 无法正常工作。我做错了什么?
解决方案
可以在多个并发 python 进程中运行多个预测,只需要在每个独立进程中构建自己的 tensorflow 计算图,然后调用 keras.model.predict
编写一个将与多处理模块(使用 Process 或 Pool 类)一起使用的函数,在此函数中,您应该构建模型、张量流图和任何您需要的东西,设置所有张量流和 keras 变量,然后您可以调用预测方法就可以了,然后将结果通过管道传回您的主进程。
例如:
def f(data):
import tensorflow, keras
configure your tensorflow and keras settings (e.g. GPU/CPU usage)
keras_model = build_your_keras_model()
result = keras_model.predict(data)
return result
if __main__ = '__main__':
p = Pool(processes = 4)
res = p.apply_async(f, (data,))
print(res.get(timeout = 10))