tensorflow - TensorFlow - 在 fit_generator 中使用 class_weights 会导致内存泄漏
问题描述
在 TensorFlow 中,当在 fit_generator 中使用 class_weights 时,会导致训练过程不断消耗越来越多的 CPU RAM,直到耗尽。在每个 epoch 之后,内存使用量都会逐步增加。请参阅下面的可重现示例。为了使可重现的示例保持较小,我减小了数据集的大小和批量大小,这显示了内存增加的趋势。在使用我的实际数据进行训练时,它会耗尽 70 EPOCS 的全部 128GB RAM。
有人遇到这个问题或对此有任何建议吗?我的数据有不平衡的数据,所以我必须使用 class_weights 但我不能长时间运行训练。
在下面的代码示例中,如果您注释掉类权重,程序会在不消耗内存的情况下进行训练。
第一张图片显示了使用 class_weights 的内存使用情况,而第二张图片显示了没有使用 class_weights 的使用情况。
import tensorflow as tf
tf.enable_eager_execution()
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import CuDNNLSTM, Dense
from tensorflow.keras.optimizers import Adadelta
feature_count = 25
batch_size = 16
look_back = 5
target_groups = 10
def random_data_generator( ):
x_data_size =(batch_size, look_back, feature_count) # batches, lookback, features
x_data = np.random.uniform(low=-1.0, high=5, size=x_data_size)
y_data_size = (batch_size, target_groups)
Y_data = np.random.randint(low=1, high=21, size=y_data_size)
return x_data, Y_data
def get_simple_Dataset_generator():
while True:
yield random_data_generator()
def build_model():
model = Sequential()
model.add(CuDNNLSTM(feature_count,
batch_input_shape=(batch_size,look_back, feature_count),
stateful=False))
model.add(Dense(target_groups, activation='softmax'))
optimizer = Adadelta(learning_rate=1.0, epsilon=None)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
return model
def run_training():
model = build_model()
train_generator = get_simple_Dataset_generator()
validation_generator = get_simple_Dataset_generator()
class_weights = {0:2, 1:8, 2:1, 3:4, 4:8, 5:35, 6:30, 7:4, 8:5, 9:3}
model.fit_generator(generator = train_generator,
steps_per_epoch=1,
epochs=1000,
verbose=2,
validation_data=validation_generator,
validation_steps=20,
max_queue_size = 10,
workers = 0,
use_multiprocessing = False,
class_weight = class_weights
)
if __name__ == '__main__':
run_training()
解决方案
对于任何未来的用户,夜间构建中似乎存在一个错误,该错误似乎在随后的夜间构建中得到修复。错误报告中的更多详细信息。
推荐阅读
- django - 我可以在 Django 视图中破坏性地更改我的数据库吗?
- macos - 在 Mac 上安装 PANDAS,大问题
- npm - 如何更新 Npm 包?
- android - 为什么是“Fragment fragment = null”,这个片段定义的含义是什么?
- jenkins - Jenkins Pipeline - Codeception 测试 publishHTML - 没有 HTML 显示
- replace - 用pyspark中的中位数替换空值
- amazon-web-services - 从 DynamoDb 查询的 Python 脚本不提供所有项目
- javascript - 在javascript中过滤嵌套数组的问题
- python - 比较 csv 中的数据并根据字典查找打印更改
- reactjs - 映射数组和渲染结果