首页 > 解决方案 > 多次加载 Keras 模型时 GPU 内存泄漏

问题描述

我有一个文件train.py来训练神经网络模型tensorflow.keras,最好的模型将best_train.h5根据train_loss. 培训时间约为7天。

我有另一个文件test.py来测试运行期间保存的模型train.py。在模型中,我每小时test.py加载一次以查看测试性能。best_train.h5代码如下:

for i in range(7*24):
  time.sleep(1*60*60)
  model = tf.keras.models.load_model('best_train.h5')
  model.predict(test_data)

我发现每次加载best-train.h5时,占用的GPU内存都会增加。而在大约 200 次迭代之后,GPU 内存就会耗尽。将发生 OOM 错误。

我知道tf.clear_session()可以释放 GPU 内存。但是此命令将清除所有会话和图形。我不是我想要的。因为test.py我还持有其他模型。

标签: pythontensorflowkeras

解决方案


由于 Keras 在所有模型之间共享全局会话。您可以创建一个新图表并分配一个会话仅用于预测:

self.graph = tf.Graph()
with self.graph.as_default():
   self.session = tf.Session(graph=self.graph)
   with self.session.as_default(): 
       # Load your model and preform prediction

完成预测后,您的 GPU 内存应立即释放。


推荐阅读