首页 > 解决方案 > 在递归循环期间分配给 TensorFlow 变量

问题描述

在 Tensorflow 1.9 中,我想创建一个网络,然后将网络的输出(预测)递归地反馈回网络的输入。在这个循环中,我想将网络所做的预测存储在一个列表中。

这是我的尝试:

    # Define the number of steps over which to loop the network
    num_steps = 5

    # Define the network weights
    weights_1 = np.random.uniform(0, 1, [1, 10]).astype(np.float32)
    weights_2 = np.random.uniform(0, 1, [10, 1]).astype(np.float32)

    # Create a variable to store the predictions, one for each loop
    predictions = tf.Variable(np.zeros([num_steps, 1]), dtype=np.float32)

    # Define the initial prediction to feed into the loop
    initial_prediction = np.array([[0.1]], dtype=np.float32)
    x = initial_prediction

    # Loop through the predictions
    for step_num in range(num_steps):
        x = tf.matmul(x, weights_1)
        x = tf.matmul(x, weights_2)
        predictions[step_num-1].assign(x)

    # Define the final prediction
    final_prediction = x

    # Start a session
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Make the predictions
    last_pred, all_preds = sess.run([final_prediction, predictions])
    print(last_pred)
    print(all_preds)

这打印出来:

[[48.8769]]

[[0.]
 [0.]
 [0.]
 [0.]
 [0.]]

因此,虽然 的值final_prediction看起来是正确的,但 的值predictions并不是我所期望的。尽管predictionspredictions[step_num-1].assign(x).

请有人向我解释为什么这不起作用,以及我应该做什么?谢谢!

标签: pythontensorflow

解决方案


发生这种情况是因为assign它只是一个 TF 操作,就像其他任何操作一样,因此仅在需要时才执行。由于路径上的任何内容都不final_prediction依赖于分配操作,并且predictions只是一个变量,因此永远不会执行分配。

我认为最直接的解决方案是更换线路

predictions[step_num-1].assign(x)

经过

x = predictions[step_num-1].assign(x)

这是有效的,因为assign它也返回了它分配的值。现在,计算final_predictionTF 实际上需要“通过”assign操作,因此应该执行分配。

另一种选择是使用tf.control_dependencieswhich 是一种在计算其他操作时“强制”TF 计算特定操作的方法。但是在这种情况下,它可能有点棘手,因为我们想要强制执行的操作 ( assign) 取决于在循环中计算的值,我不确定在这种情况下 TF 执行操作的顺序。以下应该有效:

for step_num in range(num_steps):
    x = tf.matmul(x, weights_1)
    x = tf.matmul(x, weights_2)
    with tf.control_dependencies([predictions[step_num-1].assign(x)]):
        x = tf.identity(x)

我们tf.identity用作 noop 只是为了有一些东西可以用control_dependencies. 我认为这是两者之间更灵活的选择。但是,它带有一些在文档中讨论的警告。


推荐阅读