首页 > 解决方案 > 使用来自多处理池的 Keras 模型预测函数

问题描述

我看过一些关于这个主题的 类似 帖子,但似乎没有一个能解决我的问题。

我已经训练了一个 Keras 模型(仅限 CPU)并希望使用multithreading.Pool. 但是,调用predict只是挂起。没有抛出异常或任何东西。从主线程调用它工作正常。我尝试model._make_predict_function()按照之前的建议使用,但这并不能为我解决这个问题。

我已经设置了一个 Jupyter 笔记本来重现这个(Keras==2.2.4,tensorflow==1.11.0):

In  [1]: from keras.models import Sequential
         from keras.layers import Dense
         from multiprocessing.pool import Pool

In  [2]: # Create sample model from Keras documentation
         model = Sequential()
         model.add(Dense(32, activation='relu', input_dim=100))
         model.add(Dense(1, activation='sigmoid'))
         model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

         # Generate dummy data
         import numpy as np
         data = np.random.random((1000, 100))
         labels = np.random.randint(2, size=(1000, 1))

         # Train the model, iterating on the data in batches of 32 samples
         model.fit(data, labels, epochs=10, batch_size=32, verbose=0)

In  [3]: test_data = np.random.random((1,100))

         def predict(model, data):
             return model.predict(data)

         def do_predict(_=1):
             print('Prediction:', predict(model, test_data))
             print('Done')

In  [4]: do_predict()
Out [4]: Prediction: [[0.5553096]]
         Done

In  [5]: with Pool(1) as pool:
             pool.apply_async(do_predict, [1]).get()
             pool.close()
             pool.join()

在最后一步它只是挂起。谁能帮我找出这里发生了什么?不能predict异步使用吗?

标签: multithreadingkeras

解决方案


推荐阅读