python - 为什么 TensorFlow 会提示我将错误的形状和类型输入到占位符中?
问题描述
我想不通。我一直在来回走动,我知道我可以复制和粘贴一个工作教程,但我想了解为什么这不起作用。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
#simple constants
learning_rate = .01
batch_size = 100
training_epoch = 10
t = 0
l = t
#gather the data
x_train = mnist.train.images
y_train = mnist.train.labels
batch_count = int(len(x_train)/batch_size)
#Set the variables
Y_ = tf.placeholder(tf.float32, [None,10], name = 'Labels')
X = tf.placeholder(tf.float32,[None,784], name = 'Inputs')
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
#Build the graph (Y = WX + b)
Y = tf.nn.softmax(tf.matmul(X,W) + b, name = 'softmax')
cross_entropy = -tf.reduce_mean(Y_ * tf.log(Y)) * 1000.0
correct_prediction = tf.equal(tf.argmax(Y,1), tf.argmax(Y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(training_epoch):
for i in range(batch_count):
t += batch_size
print(y_train[l:t].shape)
print(x_train[l:t].shape)
print(y_train[l:t].dtype)
sess.run(train_step,feed_dict={X: x_train[l:t], Y: y_train[l:t]})
l = t
print('Epoch = ', epoch)
print("Accuracy: ", accuracy.eval(feed_dict={X: x_test, Y_: y_test}))
print('Done')
错误信息:
InvalidArgumentError: You must feed a value for placeholder tensor 'Labels_2' with dtype float and shape [?,10]
[[Node: Labels_2 = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:GPU:0"]
我也明白,要让它发挥作用,我还需要添加更多内容,但我现在想自己努力解决这个问题。我在 jupyter 笔记本上运行它。我很肯定它y_train
有一个形状 (100, 10) 和一个 float64 类型。
我已经被困了几天,所以我很感激帮助。
解决方案
您需要在Y_
调用时输入占位符张量sess.run
。
在feed_dict
,只需更改Y: y_train[l:t]
为Y_: y_train[l:t]
。这将y_train[l:t]
输入占位符。
推荐阅读
- snowflake-cloud-data-platform - 在雪花中将仓库的大小从 x-small 调整为 medium
- python - Pandas - 基于解析其他日期列添加新日期列
- json - 如何将数据集从数据库查询转换为特定的 JSON 格式以输入到 REST api
- javascript - 使用嵌套循环计算重复项的数量。如何避免比较具有相同索引的两个字符
- java - Spring Boot 应用程序的嵌入式 HTTP 服务器在哪里?
- django - 如何在annotate django中获得布尔结果?
- javascript - 如何从 mongodb 查询回调函数返回数据?
- java - for循环内for循环代码优化
- android - 我可以在物理 android 设备 5.1 棒棒糖上运行我的颤振应用程序吗
- unity3d - 从我的工作着色器转换为统一计算着色器时流体模拟中的奇怪行为