首页 > 解决方案 > 如何清除/删除张量流中的张量?

问题描述

据我了解,tf.reset_default_graph()只创建一个新图并将其设置为等于默认图。因此,先前创建的张量将只是在占用内存。我还阅读了未引用的张量不是垃圾收集的(就像 Python 中的普通变量一样)。

如果我正在运行交叉验证以搜索一组超参数并因此一次又一次地创建相同的图,我如何摆脱先前创建的张量?

标签: pythontensorflowgarbage-collectiondeep-learning

解决方案


我在设计实验时遇到了同样的问题,在研究了这个问题之后,唯一对我有用的解决方案就是这个。正如您在该链接中看到的那样,这似乎是一个设计缺陷,TF 团队似乎并不关心修复。

解决方案是为每个交叉验证迭代创建一个新流程。因此,当进程完成时,系统会杀死它并自动释放资源。

import multiprocessing

def evaluate(...):
    import tensorflow as tf
    # Your logic

for ... in cross_valiadtion_loop:
    process_eval = multiprocessing.Process(target=evaluate, args=(...))
    process_eval.start()
    process_eval.join()

推荐阅读