python - 在递归循环期间分配给 TensorFlow 变量
问题描述
在 Tensorflow 1.9 中,我想创建一个网络,然后将网络的输出(预测)递归地反馈回网络的输入。在这个循环中,我想将网络所做的预测存储在一个列表中。
这是我的尝试:
# Define the number of steps over which to loop the network
num_steps = 5
# Define the network weights
weights_1 = np.random.uniform(0, 1, [1, 10]).astype(np.float32)
weights_2 = np.random.uniform(0, 1, [10, 1]).astype(np.float32)
# Create a variable to store the predictions, one for each loop
predictions = tf.Variable(np.zeros([num_steps, 1]), dtype=np.float32)
# Define the initial prediction to feed into the loop
initial_prediction = np.array([[0.1]], dtype=np.float32)
x = initial_prediction
# Loop through the predictions
for step_num in range(num_steps):
x = tf.matmul(x, weights_1)
x = tf.matmul(x, weights_2)
predictions[step_num-1].assign(x)
# Define the final prediction
final_prediction = x
# Start a session
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Make the predictions
last_pred, all_preds = sess.run([final_prediction, predictions])
print(last_pred)
print(all_preds)
这打印出来:
[[48.8769]]
[[0.]
[0.]
[0.]
[0.]
[0.]]
因此,虽然 的值final_prediction
看起来是正确的,但 的值predictions
并不是我所期望的。尽管predictions
有predictions[step_num-1].assign(x)
.
请有人向我解释为什么这不起作用,以及我应该做什么?谢谢!
解决方案
发生这种情况是因为assign
它只是一个 TF 操作,就像其他任何操作一样,因此仅在需要时才执行。由于路径上的任何内容都不final_prediction
依赖于分配操作,并且predictions
只是一个变量,因此永远不会执行分配。
我认为最直接的解决方案是更换线路
predictions[step_num-1].assign(x)
经过
x = predictions[step_num-1].assign(x)
这是有效的,因为assign
它也返回了它分配的值。现在,计算final_prediction
TF 实际上需要“通过”assign
操作,因此应该执行分配。
另一种选择是使用tf.control_dependencies
which 是一种在计算其他操作时“强制”TF 计算特定操作的方法。但是在这种情况下,它可能有点棘手,因为我们想要强制执行的操作 ( assign
) 取决于在循环中计算的值,我不确定在这种情况下 TF 执行操作的顺序。以下应该有效:
for step_num in range(num_steps):
x = tf.matmul(x, weights_1)
x = tf.matmul(x, weights_2)
with tf.control_dependencies([predictions[step_num-1].assign(x)]):
x = tf.identity(x)
我们tf.identity
用作 noop 只是为了有一些东西可以用control_dependencies
. 我认为这是两者之间更灵活的选择。但是,它带有一些在文档中讨论的警告。
推荐阅读
- javascript - 无法使用 jQuery 在远程父 div 中获取元素
- mysql - 在最近的匹配字符串上连接两个表
- python - 带烧瓶的乘法表
- rxjs - 更好的订阅冗余解决方案
- android - 我们如何在 EditText 边框上写文字?
- batch-file - 如何在 Windows 批处理脚本中对动态变量值进行子串化
- mongodb - 我在尝试运行我的数据库时遇到错误
- python-3.x - 为什么这个 python 函数返回 len=7 而不是 len=6?
- javascript - 无法在 puppeteer-jest 测试框架中解析本地 JSON 数据?
- bash - 如何在要连接的每个文件之间添加新行 (\n)