Keras variable() memory leak

Problem description

I am new to Keras, and to TensorFlow in general, and have a problem. I am using some of the Keras loss functions (mainly binary_crossentropy and mean_squared_error) to calculate the loss after prediction. Since Keras only accepts its own variable type, I create K.variable tensors and supply them as arguments. This scenario is executed in a loop (with a sleep), as follows:

Get appropriate data -> predict -> calculate the loss -> return it.

Since I have multiple models that follow this pattern, I created separate TensorFlow graphs and sessions to prevent collisions (also, when exporting the models' weights I had problems with a single shared graph and session, so I had to create distinct ones for every model).

However, the memory now grows uncontrollably, from a couple of MiB to 700 MiB within a couple of iterations. I am aware of Keras's clear_session() and of gc.collect(), and I call them at the end of every iteration, but the problem persists. Here I provide a code snippet, which is not the actual code, from the project. I've created a separate script in order to isolate the problem:

import tensorflow as tf

from keras import backend as K
from keras.losses import binary_crossentropy, mean_squared_error

from time import time, sleep
import gc
from numpy.random import rand

from os import getpid
from psutil import Process

from csv import DictWriter

this_process = Process(getpid())

graph = tf.Graph()
sess = tf.Session(graph=graph)

cnt = 0
max_c = 500

with open('/home/quark/Desktop/python-test/leak-7.csv', 'a') as file:
    writer = DictWriter(file, fieldnames=['time', 'mem'])
    writer.writeheader()

    while cnt < max_c:  
        with graph.as_default(), sess.as_default():         
            y_true = K.variable(rand(36, 6))
            y_pred = K.variable(rand(36, 6))

            rec_loss = K.eval(binary_crossentropy(y_true, y_pred))
            val_loss = K.eval(mean_squared_error(y_true, y_pred))

            writer.writerow({
                'time': int(time()),
                'mem': this_process.memory_info().rss
            })

        K.clear_session()
        gc.collect()

        cnt += 1
        print(max_c - cnt)
        sleep(0.1)

Additionally, I've added a plot of the memory usage (image: "Keras memory leak").

Any help is appreciated.

Tags: python, tensorflow, memory-leaks, keras

Solution


I just removed the with statement (and with it the explicit TF graph/session code), and I no longer see a leak. The key point is that the Keras session is distinct from the tf.Session you created yourself: K.clear_session() clears Keras's own default session, so you were never clearing the session in which your variables actually lived. Calling tf.reset_default_graph() could work too.

import gc
from time import sleep

from numpy.random import rand
from keras import backend as K
from keras.losses import binary_crossentropy, mean_squared_error

while True:
    # Variables are now created in Keras's own default session/graph ...
    y_true = K.variable(rand(36, 6))
    y_pred = K.variable(rand(36, 6))

    rec_loss = K.eval(binary_crossentropy(y_true, y_pred))
    val_loss = K.eval(mean_squared_error(y_true, y_pred))

    # ... which is exactly the session K.clear_session() destroys.
    K.clear_session()
    gc.collect()

    sleep(0.1)
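If you only need the loss values (no gradients), another way to sidestep graph/session bookkeeping entirely is to compute the two losses directly in NumPy, since the inputs are already NumPy arrays. This is a sketch, not the Keras implementation: the function names are hypothetical, and the 1e-7 clipping value mirrors Keras's default fuzz factor.

```python
import numpy as np

def binary_crossentropy_np(y_true, y_pred, eps=1e-7):
    # Clip predictions away from 0 and 1 to avoid log(0); eps mirrors
    # Keras's default epsilon (an assumption, see lead-in).
    p = np.clip(y_pred, eps, 1.0 - eps)
    return -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p), axis=-1)

def mean_squared_error_np(y_true, y_pred):
    # Per-sample mean squared error over the last axis, like Keras's MSE.
    return np.mean(np.square(y_true - y_pred), axis=-1)

y_true = np.random.rand(36, 6)
y_pred = np.random.rand(36, 6)

rec_loss = binary_crossentropy_np(y_true, y_pred)  # one loss per sample, shape (36,)
val_loss = mean_squared_error_np(y_true, y_pred)   # one loss per sample, shape (36,)
```

No variables or graph nodes are allocated per iteration, so there is nothing to clear and nothing to leak.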
