首页 > 解决方案 > Tensorflow v1.15中TensorArray中的左移值

问题描述

我正在尝试使用以下代码在 aTensorArray中使用 another左移值:TensorArray

import tensorflow as tf

an_array = tf.TensorArray(dtype=tf.float32, size=2, dynamic_size=False, clear_after_read=False, element_shape=(1, 2), name="First")
old_array = tf.TensorArray(dtype=tf.float32, size=2, dynamic_size=False, clear_after_read=False, element_shape=(1, 2), name="Second")
old_array = old_array.write(0, 2.*tf.ones((1, 2)))
old_array = old_array.write(1, 3.*tf.ones((1, 2)))
for _ in range(1, 5):
    val = tf.random.normal(shape=(1, 2))
    an_array = an_array.write(0, old_array.read(1))
    an_array = an_array.write(1, val)
    old_array = an_array.identity()
print(tf.Session().run([an_array.stack(), old_array.stack()]))

我加载了old_array一些初始值。在循环中,我读取数组的第二个元素old_array并将其提供给数组的第一个元素an_array。然后将新值写入an_array数组的第二个元素。的内容an_array被复制到old_array.

虽然在第一次迭代中,第一个元素严格来说不是数组的左移版本,而是基于初始条件的东西,但以后的迭代预计会给出数组an_array的左移版本。an_array

当我运行这个脚本时,我得到了错误

tensorflow.python.framework.errors_impl.InvalidArgumentError:TensorArray First_1:无法写入 TensorArray 索引 0,因为它已被写入。

有人可以指出代码有什么问题吗?谢谢。

标签: pythonarrayspython-3.xtensorflow

解决方案


我遇到了与您在tensorflow version 1.15.2. 但它在tensorflow version 2.2.0. 好像有tensorflow version 1.15.2问题。

以下是 tensorflow 2.2.0 版本的运行结果。

张量流 2.2.0 -

#%tensorflow_version 2.x
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
print(tf.__version__)

an_array = tf.TensorArray(dtype=tf.float32, size=2, dynamic_size=False, clear_after_read=False, element_shape=(1, 2), name="First")
old_array = tf.TensorArray(dtype=tf.float32, size=2, dynamic_size=False, clear_after_read=False, element_shape=(1, 2), name="Second")
old_array = old_array.write(0, 2.*tf.ones((1, 2)))
old_array = old_array.write(1, 3.*tf.ones((1, 2)))
for _ in range(1, 5):
    val = tf.random.normal(shape=(1, 2))
    an_array = an_array.write(0, old_array.read(1))
    an_array = an_array.write(1, val)
    old_array = an_array.identity()

print(tf.compat.v1.Session().run([an_array.stack(), old_array.stack()]))

输出 -

2.2.0
[array([[[ 0.6800341 , -0.72848713]],

       [[ 0.6648014 , -0.40344566]]], dtype=float32), array([[[ 0.6800341 , -0.72848713]],

       [[ 0.6648014 , -0.40344566]]], dtype=float32)]

希望这能回答你的问题。快乐学习。


推荐阅读