python - 如何阻止 CUDA 为每个训练 keras 模型的子进程重新初始化?
问题描述
我正在使用 CUDA/CUDNN 在我的 GPU 上训练多个 tensorflow keras 模型(用于尝试优化超参数的进化算法)。最初,该程序会在几代后因内存不足错误而崩溃。最终,我发现为每个模型使用一个新的子进程会自动清除 GPU 内存。
但是,每个进程似乎都在重新初始化 CUDA(从 .dll 文件加载动态库),这非常耗时。有什么方法可以避免这种情况吗?
代码粘贴在下面。为每个 idual 调用函数“fitness_wrapper” indiv
。
def fitness_wrapper(indiv):
fit = multi.processing.Value('d', 0.0)
if __name__ == '__main__':
process = multiprocessing.Process(target=fitness, args=(indiv, fit))
process.start()
process.join()
return (fit.value,)
def fitness(indiv, fit):
model = tf.keras.Sequential.from_config(indiv['architecture'])
optimizer_dict = indiv['optimizer']
opt = tf.keras.optimizers.Adam(learning_rate=optimizer_dict['lr'], beta_1=optimizer_dict['b1'],
beta_2=optimizer_dict['b2'],
epsilon=optimizer_dict['epsilon'])
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
model.fit(data_split[0], data_split[2], batch_size=32, epochs=5)
fit = model.evaluate(data_split[1], data_split[3])[1]
解决方案
原来解决方案是在每个模型(而不是子流程)之后使用 tensorflow.backend.clear_session() 。我以前试过这个,但它没有用,但由于某种原因,这次它修复了所有问题。
显然,您还应该删除模型并调用 reset_default_graph()。
def fitness(indiv, fit):
model = tf.keras.Sequential.from_config(indiv['architecture'])
optimizer_dict = indiv['optimizer']
opt = tf.keras.optimizers.Adam(learning_rate=optimizer_dict['lr'], beta_1=optimizer_dict['b1'],
beta_2=optimizer_dict['b2'],
epsilon=optimizer_dict['epsilon'])
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
model.fit(data_split[0], data_split[2], batch_size=32, epochs=5)
fit = model.evaluate(data_split[1], data_split[3])[1]
del model
tf.keras.backend.clear_session()
tf.compat.v1.reset_default_graph()
return fit
推荐阅读
- javascript - 如何将更改时的列值映射到handsontable中的第1列值?
- javascript - 当为“var”和“let”分配一个引发错误的函数的返回值时,是什么导致了它们之间的不同行为
- java - 仅当表存在时如何从表中选择值
- c++ - MacOS 上 Qt GUI 小部件的异常大小
- mysql - 有没有办法将一个字符串拆分为多个字符串?
- node.js - 我无法在 heroku 中部署 google oauth 应用程序
- java - 泰语字符转换为问号
- java - Java runnable 在 zeromq 发布者断开连接时死亡
- python - 结合 np.equal 和 np.less 创建单个数据框?
- php - xml中的错误字节