python - tensorflow(tf)变量初始化形状错误?
问题描述
下面的代码旨在通过一个while循环构建“sy_logprob_n”,并在每个循环中填充一行dim。但是当我运行它时,它说
ValueError:使用输入暗淡 0 的索引超出范围;对于“while/strided_slice_1”(操作:“StridedSlice”),输入只有 0 个暗淡,输入形状为:[]、[1]、[1]、[1],计算输入张量为:input[3] = <1> .
奇怪的是,当我将形状打印为 时print(tf.shape(sy_logprob_n),tf.shape(sy_ac_na))
,它们的两种形状是不同的:
张量("Shape_2:0", shape=(0,), dtype=int32) 张量("Shape_3:0", shape=(2,), dtype=int32)
但是“sy_logprob_n”被初始化为“sy_ac_na”的形状。
有人有想法吗?提前致谢!!!
sy_ac_na = tf.placeholder(shape=[None, ac_dim], name="ac", dtype=tf.float32)
batch_size=tf.shape(sy_ob_no)[0]
sy_mean = build_mlp(sy_ob_no, ac_dim, 'mean', n_layers, size) # shape (batch,ac_dim)
sy_logstd = tf.Variable(initial_value=0,name='std',expected_shape=[ac_dim],dtype=tf.float32) # logstd should just be a trainable variable, not a network output.
sy_sampled_ac=tf.random_normal(tf.shape(sy_ac_na),mean=sy_mean,stddev=tf.exp(sy_logstd),seed=seed) # shape(batch,ac_dim)
sy_logprob_n=tf.Variable(initial_value=0,expected_shape=tf.shape(sy_ac_na),name='sy_logprob_n',dtype=tf.float32)
print(tf.shape(sy_logprob_n),tf.shape(sy_ac_na))
def cond(sy_logprob_n,i,sy_ac_na):
return tf.less(i,batch_size)
def body(sy_logprob_n,i,sy_ac_na):
distribution=tf.distributions.Normal(sy_mean,tf.exp(sy_logstd))
log_prob=distribution.log_prob(sy_ac_na[i,:])
tf.assign(sy_logprob_n[i],log_prob)
i+=1
return sy_logprob_n
sy_logprob_n = tf.squeeze(tf.while_loop(cond,body,[sy_logprob_n,0,sy_ac_na])) # shape (batch,)
解决方案
即使您指定了expected_shape
,tf.Variable
也会采用initial_value
(此处为 0 )的形状作为其形状。自 v1.0.0 以来,该expected_shape
参数似乎已被弃用。我不确定为什么它仍然记录在tf.Variable
. 我刚刚在 Github 上提交了一个问题。
推荐阅读
- c# - 我有这个问题不是所有的代码路径都返回一个值
- if-statement - 多个 IF if 语句闪亮
- python - 你如何测试一个 python 类型的协议是另一个协议的子类?
- sql-server - ASP.NET Core Web API 服务无法连接到同一网络上的 Docker SQL Server
- javascript - 将项目添加到 JavaScript 中的对象数组中
- c# - 在 C# 中播放音频时出现白噪声
- java - 如何配置 GraphQL SPQR 以使用 Gson 而不是 Jackson
- xamarin - 框架内的 Xamarin 表单标签未显示
- r - 如何反转该图的 x 轴?
- phpmailer - 此脚本当前仅支持 Google、Microsoft 和 Yahoo OAuth2 提供程序。- php邮件程序