python - Tensorflow v1.15中TensorArray中的左移值
问题描述
我正在尝试使用以下代码在 aTensorArray
中使用 another左移值:TensorArray
import tensorflow as tf
an_array = tf.TensorArray(dtype=tf.float32, size=2, dynamic_size=False, clear_after_read=False, element_shape=(1, 2), name="First")
old_array = tf.TensorArray(dtype=tf.float32, size=2, dynamic_size=False, clear_after_read=False, element_shape=(1, 2), name="Second")
old_array = old_array.write(0, 2.*tf.ones((1, 2)))
old_array = old_array.write(1, 3.*tf.ones((1, 2)))
for _ in range(1, 5):
val = tf.random.normal(shape=(1, 2))
an_array = an_array.write(0, old_array.read(1))
an_array = an_array.write(1, val)
old_array = an_array.identity()
print(tf.Session().run([an_array.stack(), old_array.stack()]))
我加载了old_array
一些初始值。在循环中,我读取数组的第二个元素old_array
并将其提供给数组的第一个元素an_array
。然后将新值写入an_array
数组的第二个元素。的内容an_array
被复制到old_array
.
虽然在第一次迭代中,第一个元素严格来说不是数组的左移版本,而是基于初始条件的东西,但以后的迭代预计会给出数组an_array
的左移版本。an_array
当我运行这个脚本时,我得到了错误
tensorflow.python.framework.errors_impl.InvalidArgumentError:TensorArray First_1:无法写入 TensorArray 索引 0,因为它已被写入。
有人可以指出代码有什么问题吗?谢谢。
解决方案
我遇到了与您在tensorflow version 1.15.2
. 但它在tensorflow version 2.2.0
. 好像有tensorflow version 1.15.2
问题。
以下是 tensorflow 2.2.0 版本的运行结果。
张量流 2.2.0 -
#%tensorflow_version 2.x
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
print(tf.__version__)
an_array = tf.TensorArray(dtype=tf.float32, size=2, dynamic_size=False, clear_after_read=False, element_shape=(1, 2), name="First")
old_array = tf.TensorArray(dtype=tf.float32, size=2, dynamic_size=False, clear_after_read=False, element_shape=(1, 2), name="Second")
old_array = old_array.write(0, 2.*tf.ones((1, 2)))
old_array = old_array.write(1, 3.*tf.ones((1, 2)))
for _ in range(1, 5):
val = tf.random.normal(shape=(1, 2))
an_array = an_array.write(0, old_array.read(1))
an_array = an_array.write(1, val)
old_array = an_array.identity()
print(tf.compat.v1.Session().run([an_array.stack(), old_array.stack()]))
输出 -
2.2.0
[array([[[ 0.6800341 , -0.72848713]],
[[ 0.6648014 , -0.40344566]]], dtype=float32), array([[[ 0.6800341 , -0.72848713]],
[[ 0.6648014 , -0.40344566]]], dtype=float32)]
希望这能回答你的问题。快乐学习。
推荐阅读
- r - dplyr::select 中的错误 - 找不到对象“符号”
- d3.js - 如何在d3中绘制两条不同长度(不同x值)的曲线之间的区域/填充
- vue.js - 在 Vuex 中基于状态属性进行计算的正确方法是什么?
- android - 支持可穿戴设备的应用程序未出现在 Play 商店中,并且所有设备都不兼容
- c - 在使用 Cython(或 ctypes)调用的 Visual Studio 编译器上使用 OpenMP 的 C 扩展逐渐减速到停止
- sql - 填写滚动平均值的空白日期 - Snowflake 中的 CTE
- python - 在范围内查找 Pandas 数据帧的交集
- python - 换行有时无法在 .csv (Python/Pandas) 中正确显示
- huawei-mobile-services - HMS InAppPurchase - IapClient.isEnvReady() 失败,StatusCode 6003
- excel - 使用 vba 定义动态范围