multithreading - 使用来自多处理池的 Keras 模型预测函数
问题描述
我看过一些关于这个主题的 类似 帖子,但似乎没有一个能解决我的问题。
我已经训练了一个 Keras 模型(仅限 CPU)并希望使用multithreading.Pool
. 但是,调用predict
只是挂起。没有抛出异常或任何东西。从主线程调用它工作正常。我尝试model._make_predict_function()
按照之前的建议使用,但这并不能为我解决这个问题。
我已经设置了一个 Jupyter 笔记本来重现这个(Keras==2.2.4,tensorflow==1.11.0):
In [1]: from keras.models import Sequential
from keras.layers import Dense
from multiprocessing.pool import Pool
In [2]: # Create sample model from Keras documentation
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])
# Generate dummy data
import numpy as np
data = np.random.random((1000, 100))
labels = np.random.randint(2, size=(1000, 1))
# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32, verbose=0)
In [3]: test_data = np.random.random((1,100))
def predict(model, data):
return model.predict(data)
def do_predict(_=1):
print('Prediction:', predict(model, test_data))
print('Done')
In [4]: do_predict()
Out [4]: Prediction: [[0.5553096]]
Done
In [5]: with Pool(1) as pool:
pool.apply_async(do_predict, [1]).get()
pool.close()
pool.join()
在最后一步它只是挂起。谁能帮我找出这里发生了什么?不能predict
异步使用吗?
解决方案
推荐阅读
- flutter - 如何在导航到视图时将数据发送到控制器
- java - 将生产者连接到 kafka 主题
- javascript - 样式化多重自动完成策略
- python - Python AttributeError:'super'对象没有属性'testnet',但是当在super上调用__dict__时出现该属性?
- tensorflow - 这个模型是否代表了用于对象检测的 Faster R-CNN 模型?
- reactjs - 我可以使用 jsx 在同一个元素中添加新类吗?
- git - 在 Visual Studio Code 中处理 Alexa Skill 中的 git 问题
- java - 如何使用 VectorAssembler 设置火花数据集的 n 个特征?
- c# - ExecuteAsync 失败,但 Execute 工作正常
- python - 如何处理来自导入模块的未知类型的函数/方法