首页 > 解决方案 > tensorflow while循环减慢

问题描述

问题是一个 tensorflow while 循环(tf.while_loop)随着时间的推移而减慢。该循环应该返回一些矩阵。我通过字典提供所有输入。

我知道这个问题很可能是由于一遍又一遍地添加操作而污染了图表。我是 TF 初学者,对我来说,这里的图表受到污染的原因并不明显。非常感谢任何帮助。

def predict(self, actions, ...):


    feed_dict = {
        self.agent.actions: actions.reshape(-1, self.kwargs["dim_actions"]),
        ...
    }

    states_mu, states_var = self.session.run(self.agent.predict_states(), feed_dict=feed_dict)

    return states_mu, states_var


def predict_states(self):
   ...

    def loop_cond(i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov):
        return i < self.episode_length

    def loop_body(i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov):
        state_mu_i = state_mus[-1][None, :]
        ...
        state_var_tf = state_vars_tf[-1][None, :, :]

        #Some math operations
        ...

        new_state_mu = state_mu_i + delta_mu
        new_state_var = state_var_i + delta_var + inp_out_cov

        new_mu_tf, new_var_tf, inp_tf_cov = some_transform(
            new_state_mu, ....)

        state_mus = tf.concat([state_mus, new_state_mu], 0)
        ...
        state_vars_tf = tf.concat([state_vars_tf, new_var_tf], 0)

        i += 1

        return i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov

    loop_step = tf.constant(0, tf.int32)
    init_mus_tf, init_vars_tf, inp_tf_cov = some_transform(
        self.state_mu, self.state_var, self.dim_angles)

    loop_vars = [
        loop_step,
        self.state_mu,
        self.state_var,
        init_mus_tf,
        init_vars_tf,
        inp_tf_cov]

    shapes = [loop_step.get_shape(),
              tf.TensorShape([None, self.dim_states]),
              tf.TensorShape([None, self.dim_states, self.dim_states]),
              tf.TensorShape([None, self.dim_states_tf]),
              tf.TensorShape([None, self.dim_states_tf, self.dim_states_tf]),
              inp_tf_cov.get_shape()]

    _, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov = tf.while_loop(
        loop_cond,
        loop_body,
        loop_vars=loop_vars,
        shape_invariants=shapes)

    return state_mus_tf[1:], state_vars_tf[1:]

该循环被多次调用。它在一次运行中减慢,即在每次迭代之后,甚至在重复调用之后更慢。每次运行的迭代速度从上次运行结束的地方开始。例如,在第一次运行开始时,每次迭代需要 1 秒,在第一次运行结束时,每次迭代需要 3 秒。在第二次运行开始时,每次迭代需要 3 秒,......直到让它运行变得不可行(例如,每次迭代需要 100 秒)。

标签: pythontensorflow

解决方案


代码看起来大部分都很好,但是您应该predict_states只调用一次,当您创建类的实例时(或在其他一些初始化步骤中),并将返回值存储在类属性中。例如:

def __init__(self, ...):
    # ...
    self.states_mu_tf, self.states_var_tf = self.agent.predict_states()

然后,您使用这些属性predict

states_mu, states_var = self.session.run((self.states_mu_tf, self.states_var_tf),
                                         feed_dict=feed_dict)

这样您就不会在图中重新创建操作。


推荐阅读