tensorflow2.0 - ValueError:形状在assign_add()中的等级必须相等
问题描述
我正在阅读TF2 中 Tensorflow r2.0中的 tf.Variable:
import tensorflow as tf
# Create a variable.
w = tf.constant([1, 2, 3, 4], tf.float32, shape=[2, 2])
# Use the variable in the graph like any Tensor.
y = tf.matmul(w,tf.constant([7, 8, 9, 10], tf.float32, shape=[2, 2]))
v= tf.Variable(w)
# The overloaded operators are available too.
z = tf.sigmoid(w + y)
tf.shape(z)
# Assign a new value to the variable with `assign()` or a related method.
v.assign(w + 1)
v.assign_add(tf.constant([1.0, 21]))
ValueError: Shapes must be equal rank, but are 2 and 1 for 'AssignAddVariableOp_4' (op: 'AssignAddVariableOp') with input shapes: [], 2 .
还有为什么下面的返回错误?
tf.shape(v) == tf.shape(tf.constant([1.0, 21],tf.float32))
我的另一个问题是,当我们在 TF 2 中时,我们不应该再使用 tf.Session() 了,对吗?似乎我们永远不应该运行 session.run(),但是 API 文档密钥使用 tf.compat.v1 等执行它。那么为什么他们在 TF2 文档中使用它呢?
任何帮助,将不胜感激。
CS
解决方案
正如它在错误中明确指出的那样,它期望形状为 [2,2] 的assign_add
v 上的形状为 [2,2]。如果您尝试给出除您尝试执行的张量的初始形状之外的任何形状,assign_add
则会给出错误。
下面是修改后的代码,具有操作的预期形状。
import tensorflow as tf
# Create a variable.
w = tf.constant([1, 2, 3, 4], tf.float32, shape=[2, 2])
# Use the variable in the graph like any Tensor.
y = tf.matmul(w,tf.constant([7, 8, 9, 10], tf.float32, shape=[2, 2]))
v= tf.Variable(w)
# The overloaded operators are available too.
z = tf.sigmoid(w + y)
tf.shape(z)
# Assign a new value to the variable with `assign()` or a related method.
v.assign(w + 1)
print(v)
v.assign_add(tf.constant([1, 2, 3, 4], tf.float32, shape=[2, 2]))
v 的输出:
<tf.Variable 'UnreadVariable' shape=(2, 2) dtype=float32, numpy=
array([[3., 5.],
[7., 9.]], dtype=float32)>
现在下面的张量比较正在返回True
。
tf.shape(v) == tf.shape(tf.constant([1.0, 21],tf.float32))
<tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, True])>
谈到你的tf.Session()
问题,在 TensorFlow 2.0 中,Eager Execution 默认是启用的,但如果你需要禁用 Eager Execution 并且可以tf.Session
像下面这样使用。
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
hello = tf.constant('Hello, TensorFlow!')
sess = tf.compat.v1.Session()
print(sess.run(hello))
推荐阅读
- javascript - Next.js中如何在组件之间传递状态值如何制作StateProvider
- assembly - 当我尝试编译时出现此错误:boot.s:18: `_start' 的多个定义;
- malloc - 请求内存 sbrk 时内存对齐是怎么回事?
- spring-boot - maven spring boot:配置jvmbehin代理
- eclipse - Eclipse 打开方式选项
- javascript - javascript mqtt websocket在localhost中工作正常,在https服务器中不起作用
- javascript - Node.js 快速路由器 url/url 不起作用
- javascript - 未处理的拒绝类型错误:无法读取未定义的属性“推送”
- python - Python Xlwings 执行问题
- c++ - 继承访问问题