python - 如何将张量类型值传递给变量的形状参数?
问题描述
我遇到了一个问题,可以总结如下:
foo = tf.constant(3)
foo_variable = tf.get_variable("foo", shape=[foo], dtype=tf.int32)
变量的形状必须取决于张量的值(foo
这里只是对其他操作的计算结果的抽象)
这里的错误是The shape of a variable can not be a Tensor object
如何解决这个问题?
解决方案
创建一个具有张量指定形状的张量初始化器,foo
然后使用此初始化器实例化一个新变量validate_shape=False
:
import tensorflow as tf
x = tf.placeholder(tf.int32, shape=())
shape = tf.constant([2, 3]) + x
init = tf.zeros(shape, dtype=tf.int32)
v = tf.get_variable('foo', initializer=init, validate_shape=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer(), {x: 1})
print(v.eval())
# [[0 0 0 0]
# [0 0 0 0]
# [0 0 0 0]]
推荐阅读
- javascript - JS克隆后用按钮删除div
- azure - 适用于 Apple iOS 13 的 Azure 推送通知服务失败
- fpga - 如何修复 quartus 14.1 web edition on linux 在使用几分钟后抛出的这个错误?
- serverless-framework - 文件 '../foo.ts' 不在 rootDir 下。rootDir 应包含所有源文件
- python - 由于 json.load 读取大 json 文件时出错
- python - 为什么我在 python 3 中收到“function upper(bytea)”错误?
- node.js - 我正在尝试从公共 s3 存储桶中复制一个文件并将其放入我的 dynamodb 数据库中。我在这里做错了什么?
- javascript - 单击 li 时未应用活动类
- mysql - 带有 Angular 8、节点和 mysql 的数据表
- azure - 解决对 Adaptive Cards 1.2 缺乏支持的问题