python-3.x - ValueError:无法为具有形状“(?,784)”的张量“x:0”提供形状(784,)的值
问题描述
这是我第一次使用 TensorFlow。这个 ValueError似乎有很多查询,但是我没有得到任何缓解。我正在使用 notMNIST 数据集,它是拆分 70/30 训练测试。
错误消息似乎表明我的小批量存在问题。我已经打印了占位符的形状,重新调整了输入和标签数据,但没有成功。
import tensorflow as tf
tf.reset_default_graph()
num_inputs = 28*28 # Size of images in pixels
num_hidden1 = 500
num_hidden2 = 500
num_outputs = len(np.unique(y)) # Number of classes (labels)
learning_rate = 0.0011
inputs = tf.placeholder(tf.float32, shape=[None, num_inputs], name="x")
labels = tf.placeholder(tf.int32, shape=[None], name = "y")
print(np.expand_dims(inputs, axis=0))
print(np.expand_dims(labels, axis=0))
def neuron_layer(x, num_neurons, name, activation=None):
with tf.name_scope(name):
num_inputs = int(x.get_shape()[1])
stddev = 2 / np.sqrt(num_inputs)
init = tf.truncated_normal([num_inputs, num_neurons], stddev=stddev)
W = tf.Variable(init, name = "weights")
b = tf.Variable(tf.zeros([num_neurons]), name= "biases")
z = tf.matmul(x, W) + b
if activation == "sigmoid":
return tf.sigmoid(z)
elif activation == "relu":
return tf.nn.relu(z)
else:
return z
with tf.name_scope("dnn"):
hidden1 = neuron_layer(inputs, num_hidden1, "hidden1", activation="relu")
hidden2 = neuron_layer(hidden1, num_hidden2, "hidden2", activation="relu")
logits = neuron_layer(hidden2, num_outputs, "output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
loss = tf.reduce_mean(xentropy, name="loss")
with tf.name_scope("evaluation"):
correct = tf.nn.in_top_k(logits, labels, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
grads = optimizer.compute_gradients(loss)
training_op = optimizer.apply_gradients(grads)
for var in tf.trainable_variables():
tf.summary.histogram(var.op.name + "/values", var)
for grad, var in grads:
if grad is not None:
tf.summary.histogram(var.op.name + "/gradients", grad)
# summary
accuracy_summary = tf.summary.scalar('accuracy', accuracy)
# merge all summary
tf.summary.histogram('hidden1/activations', hidden1)
tf.summary.histogram('hidden2/activations', hidden2)
merged = tf.summary.merge_all()
init = tf.global_variables_initializer()
saver = tf.train.Saver()
from datetime import datetime
now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
root_logdir = "tf_logs/example03/dnn_final"
logdir = "{}/run-{}/".format(root_logdir, now)
train_writer = tf.summary.FileWriter("models/dnn0/train",
tf.get_default_graph())
test_writer = tf.summary.FileWriter("models/dnn0/test", tf.get_default_graph())
num_epochs = 50
batch_size = 128
with tf.Session() as sess:
init.run()
print("Epoch\tTrain accuracy\tTest accuracy")
for epoch in range(num_epochs):
for idx_start in range(0, x_train.shape[0], batch_size):
idx_end = num_epochs
x_batch, y_batch = x_train[batch_size], y_train[batch_size]
sess.run(training_op, feed_dict={inputs: x_batch, labels: y_batch})
summary_train, acc_train = sess.run([merged, accuracy],
feed_dict={x: x_batch, y: y_batch})
summary_test, acc_test = sess.run([accuracy_summary, accuracy],
feed_dict={x: x_test, y: y_test})
train_writer.add_summary(summary_train, epoch)
test_writer.add_summary(summary_test, epoch)
print("{}\t{}\t{}".format(epoch, acc_train, acc_test))
save_path = saver.save(sess, "models/dnn0.ckpt")
以下错误
ValueError:无法为具有形状“(?,784)”的张量“x:0”提供形状(784,)的值
发生在第 96 行
sess.run(training_op,feed_dict={输入:x_batch,标签:y_batch})
解决方案
在这条线上,你指的是inputs
和labels
sess.run(training_op, feed_dict={inputs: x_batch, labels: y_batch})
在下面的行中,
summary_train, acc_train = sess.run([merged, accuracy],
feed_dict={x: x_batch, y: y_batch})
summary_test, acc_test = sess.run([accuracy_summary, accuracy],
feed_dict={x: x_test, y: y_test})
你指的是x
和y
。将这些更改为相同。即它应该与占位符变量的值相同。(inputs
和labels
)
推荐阅读
- c++ - 错误:重新定义“类...”继承
- python - 在python中计算网格单元格中的点,np.histogramdd
- java - Jsonpath 在 Spring Boot Test 中为嵌套的 JSON 对象返回 null
- asynchronous - 如何结束(完成)基于另一个无限流创建的 async* 流?
- excel - 根据两个条件查找最大值,然后写入单元格
- scala - 如何将文件夹及其内容添加到 Playframework 和 Heroku 的标准路径?
- symfony - Twig 表单模板:无法从“Symfony\Component\Form\FormView”访问 PersistentCollection 属性。Symfony 4
- netsuite - 通过按钮创建 NetSuite 高级 PDF
- r - R基于匹配的嵌套列表名称对嵌套列表元素应用函数
- python - Python - 打开带有空格的文件时遇到问题