首页 > 解决方案 > 在多处理中使用 model.predict (Keras + TF)

问题描述

我有以下问题。我正在使用 Tensorflow Keras 模型来评估连续传感器数据。我的模型输入由 15 个传感器数据帧组成。因为函数 model.predict() 需要将近 1 秒,所以我想异步执行这个函数,这样我就可以在这个时间段内收集下一个数据帧。为此,我创建了一个带有多处理库和一个用于 model.predict 的函数的池。我的代码看起来像这样:

def predictData(data): 
   return model.predict(data)

global model
model = tf.keras.models.load_model("Network.h5")
model._make_predict_function()

p = Pool(processes = 4)
...
res = p.apply_async(predictData, ([[iinput]],))
print(res.get(timeout = 10))

现在我在调用 predictData() 时总是遇到超时错误。似乎 model.predict() 无法正常工作。我做错了什么?

标签: pythontensorflowkerasmultiprocessing

解决方案


可以在多个并发 python 进程中运行多个预测,只需要在每个独立进程中构建自己的 tensorflow 计算图,然后调用 keras.model.predict

编写一个将与多处理模块(使用 Process 或 Pool 类)一起使用的函数,在此函数中,您应该构建模型、张量流图和任何您需要的东西,设置所有张量流和 keras 变量,然后您可以调用预测方法就可以了,然后将结果通过管道传回您的主进程。

例如:

    def f(data):

          import tensorflow, keras

          configure your tensorflow and keras settings (e.g.  GPU/CPU usage)

          keras_model = build_your_keras_model()

          result = keras_model.predict(data)

          return result

    if __main__ = '__main__':

          p = Pool(processes = 4)

          res = p.apply_async(f, (data,))

          print(res.get(timeout = 10))

推荐阅读