Custom gradient-tape-based fit() function is slower than TF2 fit() and has a memory leak

Problem description

I wrote a custom fit() function in TF2 using gradient tape. The @tf.function decorator is used in a few specific places to disable eager execution, which should speed up the code as a whole.
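To illustrate the behaviour I am relying on, here is a minimal sketch of how @tf.function separates trace time from execution time (traced_fn is an illustrative name, not part of my code below): Python side effects run only while the function is being traced, while tf.* ops become part of the graph and run on every call.

import tensorflow as tf

@tf.function
def traced_fn(x):
    # this Python print runs only while the function is being traced
    print("tracing")
    # tf.print is compiled into the graph and runs on every call
    tf.print("executing")
    return x * 2

traced_fn(tf.constant(1))  # prints "tracing", then "executing"
traced_fn(tf.constant(2))  # same input signature, no retrace: only "executing"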

My custom fit has three main problems:

1- It is much slower than its TF2 counterpart (the @tf.function decorator on the train_batch function does not improve the speed, whereas placing it above the fit() function does). For some reason the time indicated in the output of the custom fit() (see picture below) is wrong; it is actually constant at around 10 seconds per epoch.

2- Memory usage grows linearly (see the plot at the bottom).

3- The process gets killed even before reaching the maximum GPU memory usage (see the code output at the bottom).

Below you can find the complete code; you should be able to run it directly, since it requires no data and no extra lines of code.

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Input, Dense, Dropout, Conv1D, Flatten, MaxPooling1D
from tensorflow.keras.models import Model
import numpy as np
import warnings
import psutil
import time
import matplotlib.pyplot as plt


### DEFINE NEURAL NETWORK
def baseline_network():
    input = Input(shape=(128, 11))
    layer = Conv1D(64, 5, activation='relu', padding='same')(input) 
    layer = Flatten()(layer)
    layer = Dense(1024, activation='relu')(layer)
    output = Dense(257, activation=None)(layer)
                  
    model = Model(inputs=input, outputs=output)
    model.compile(optimizer='adam',
                  loss='mse', 
                  metrics=['mse'])
   
    return model


### DATA GENERATOR
def generator():
    
    # initialize numpy tensors
    X_out = np.zeros([6000, 128, 11])
    Y_out = np.zeros([6000, 257])

    ### CREATE DATA BATCHES ###
    # NOTE: each yield returns the full 6000-sample arrays, so every
    # "batch" drawn from this generator contains 6000 examples
    while True:
        for i in range(32):
            yield X_out, Y_out


### CUSTOM FIT USING GRADIENT TAPE
class runGradientTape:
    
    def __init__(self, model, cost_function, iterator_train, steps_train, max_epochs, batch_size):
        
        self.model = model
        self.cost_function = cost_function
        self.iterator_train = iterator_train
        self.steps_train = steps_train
        self.max_epochs = max_epochs
        self.batch_size = batch_size


    def lossWrapper(self):
        
        #@tf.function
        def lossFunction(y_true, y_pred):
            # calculating loss and squeezing single dimensions away
            loss = tf.squeeze(self.cost_function(y_pred, y_true))
            # calculate mean over batches
            loss = tf.reduce_mean(loss)
            # return the loss
            return loss
        # returning the loss function as handle
        return lossFunction
    
    
    @tf.function
    def train_batch(self, x_batch_train, y_batch_train):
        # run gradient taping
        with tf.GradientTape() as tape:
            y_hat = self.model(x_batch_train, training=True)
            loss_value = self.loss_fn(y_batch_train, y_hat)

        # calculate gradients
        grads = tape.gradient(loss_value, self.model.trainable_variables)
        self.model.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 
        
        return loss_value
    
    
    @tf.function
    def fit(self):
        
        # others
        warnings.filterwarnings("ignore")

        self.loss_fn = self.lossWrapper()
        progbar = tf.keras.utils.Progbar(self.steps_train)
        
        memory_used = []
        
        # ITERATE OVER EPOCHS 
        for epoch in range(self.max_epochs):
            
            print('Epoch = %d/%d' %(epoch+1, self.max_epochs))
            
            # iterate over steps in training generator
            for i in range(self.steps_train):
                
                x_batch_train, y_batch_train = next(self.iterator_train)
                
                loss_value = self.train_batch(x_batch_train, y_batch_train)
                progbar.update(i+1)
                
                memory_used.append(psutil.virtual_memory().used / 2 ** 30)
                #print('   memory used: ', memory_used[-1])
        
        #plt.plot(memory_used)
        #plt.title('Memory usage vs batch')
        #plt.savefig('mem_usage')

                
# parameters
steps_train = 100       
n_epochs = 5
batch_size = 32
baseline =  baseline_network()   
iterator_train = generator()

# TF2 fit()
print("TF2 fit")
start_time = time.time()
history = baseline.fit(iterator_train, steps_per_epoch=steps_train, epochs=n_epochs)
stop_time = time.time()
print("time elapsed (TF2 fit): ", stop_time-start_time)

#reset parameters
del baseline, iterator_train
baseline =  baseline_network()   
iterator_train = generator()

# custom fit
print("Custom fit")
start_time = time.time()
run_gradient_tape = runGradientTape(baseline, keras.losses.MSE, iterator_train, steps_train, n_epochs, batch_size)  
run_gradient_tape.fit()
stop_time = time.time()
print("time elapsed (custom fit): ", stop_time-start_time)

Some suggestions I have tried so far:

1- gc.collect() brings no improvement in this case.

2- Profiling the code with TensorBoard (check the profile and memory usage below); a minimal sketch of how such a profile can be captured follows.
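For reference, a minimal sketch of capturing a profile with the TensorFlow profiler (the 'logdir' path and the number of profiled steps are illustrative):

import tensorflow as tf

# normally set inside fit(); needed here to call train_batch standalone
run_gradient_tape.loss_fn = run_gradient_tape.lossWrapper()

# profile a few training steps; the trace appears in TensorBoard's Profile tab
tf.profiler.experimental.start('logdir')

for step in range(10):
    x_batch_train, y_batch_train = next(iterator_train)
    run_gradient_tape.train_batch(x_batch_train, y_batch_train)

tf.profiler.experimental.stop()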

Could you give me any advice on how to solve the three problems mentioned above?

Code output: [image]

Memory usage: [image] [image]

Timing: [image]

Tags: performance, tensorflow2.0, gradient

Solution
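The likely culprit is the @tf.function decorator on fit() itself. Tracing a tf.function runs all of the Python in its body once: the epoch and step loops are plain Python loops, so they are unrolled, and every next() call on the generator executes at trace time, baking each 6000-sample numpy array into the graph as a constant. That builds one enormous graph (memory grows with every step until the process is killed, which matches problems 2 and 3), and it makes the reported timings misleading, because print, the progress bar and the psutil bookkeeping all run during tracing rather than during training (problem 1). A sketch of the usual fix, under those assumptions: keep @tf.function only on the per-batch step, drive it from an eager Python loop, and pass tensors with a stable signature so the step is traced exactly once.

# inside runGradientTape: compile only the per-batch work
@tf.function
def train_batch(self, x_batch_train, y_batch_train):
    with tf.GradientTape() as tape:
        y_hat = self.model(x_batch_train, training=True)
        loss_value = self.loss_fn(y_batch_train, y_hat)
    grads = tape.gradient(loss_value, self.model.trainable_variables)
    self.model.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
    return loss_value

# no @tf.function here: plain eager Python driving the compiled step
def fit(self):
    self.loss_fn = self.lossWrapper()
    progbar = tf.keras.utils.Progbar(self.steps_train)
    for epoch in range(self.max_epochs):
        print('Epoch = %d/%d' % (epoch + 1, self.max_epochs))
        for i in range(self.steps_train):
            x_batch_train, y_batch_train = next(self.iterator_train)
            # convert once per step so train_batch always sees the same
            # Tensor signature and is traced only on the first call
            loss_value = self.train_batch(
                tf.convert_to_tensor(x_batch_train, dtype=tf.float32),
                tf.convert_to_tensor(y_batch_train, dtype=tf.float32))
            progbar.update(i + 1)

With this split the graph is traced once on the first batch and reused afterwards, so memory should stay flat across steps and the measured wall-clock time reflects actual training rather than tracing.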

