首页 > 解决方案 > 使用池的 TensorFlow 错误:无法腌制 _thread.RLock 对象

问题描述

我试图在 Tensorflow CPU 上实现一个大规模并行微分方程求解器(30k DEs),但内存不足(大约 30GB 矩阵)。所以我实现了一个基于批处理的求解器(解决一小段时间并保存数据 -> 设置新的初始值 -> 再次求解)。但问题依然存在。我了解到,在关闭 python 解释器之前,Tensorflow 不会清除内存。因此,根据有关 github 问题的信息,我尝试使用 pool 实现多处理解决方案,但在 Pooling 步骤中我不断收到“无法腌制 _thread.RLock 对象”。有人可以帮忙吗!

def dAdt(X,t):
  dX = // vector of differential
  return dX

global state_vector
global state

state_vector =  [0]*n // initial state

def tensor_process():
    with tf.Session() as sess:
        print("Session started...",end="")
        tf.global_variables_initializer().run()
        state = sess.run(tensor_state)
        sess.close()


n_batch = 3
t_batch = np.array_split(t,n_batch)


for n,i in enumerate(t_batch):
    print("Batch",(n+1),"Running...",end="")
    if n>0:
        i = np.append(i[0]-0.01,i)
    print("Session started...",end="")
    init_state = tf.constant(state_vector, dtype=tf.float64)
    tensor_state = tf.contrib.odeint_fixed(dAdt, init_state, i)
    with Pool(1) as p:
        p.apply_async(tensor_process).get()
    state_vector = state[-1,:]
    np.save("state.batch"+str(n+1),state)
    state=None

标签: pythontensorflowmultiprocessing

解决方案


由于许多原因,Tensorflow 不支持多处理,例如它无法分叉 TensorFlow 会话本身。如果您仍然想使用某种“多”的东西,试试这个(multiprocessing.pool.ThreadPool)对我有用:

https://stackoverflow.com/a/46049195/5276428

注意:我通过在线程上创建多个会话然后依次调用属于每个线程的每个会话变量来做到这一点。如果您的问题是内存,我认为可以通过减少输入批量大小来解决。


推荐阅读