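python - Cannot feed value of shape (100,) for Tensor 'Placeholder_1:0', which has shape '(?, 1)'
Problem description
As the title says, TensorFlow 1.14 raises the error "Cannot feed value of shape (100,) for Tensor 'Placeholder_1:0', which has shape '(?, 1)'". The error occurs in the session part of the code (the last cell). How can I fix it? I can't work out where the mistake is. The code is an IPython notebook export. Thanks in advance.
Here is the code: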
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import random
import math
import time
import h5py
# In[2]:
if len(tf.config.experimental.list_physical_devices('GPU')) > 0:
    tf.debugging.set_log_device_placement(True)
# In[3]:
df = pd.read_hdf('../Dataset/dataset.h5',key='df')
# In[4]:
df.head()
# In[5]:
df.pop('App')
df_y = df.pop('Label')
df_x = df
# In[6]:
trn_x, val_x, trn_y, val_y = train_test_split(df_x, df_y, test_size=0.3)
# In[7]:
trn_x.head()
# In[8]:
trn_y.head()
# In[9]:
val_x.head()
# In[10]:
val_y.head()
# In[11]:
train_size,num_features = trn_x.shape
print(train_size)
# In[12]:
x = tf.placeholder("float", shape=[None, num_features])
y = tf.placeholder("float", shape=[None,1])
# In[13]:
w = tf.Variable(np.random.rand(num_features,1), dtype = np.float32)
b = tf.Variable(np.random.rand(1,1), dtype = np.float32)
# In[14]:
C = 1
batch = 100
epochs = 100
# In[15]:
y_cap = tf.matmul(x,w)+b
# In[16]:
reg_loss = 0.5*tf.math.sqrt(tf.reduce_sum(tf.square(w)))
hinge_loss = C*tf.reduce_sum(tf.square(tf.maximum(tf.zeros([batch,1]),1-y*y_cap)))
loss = reg_loss + hinge_loss
# In[17]:
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# In[18]:
label = tf.sign(y_cap)
# In[19]:
c_pred = tf.equal(y,label)
# In[20]:
accuracy = tf.reduce_sum(tf.cast(c_pred, "float"))
# In[23]:
with tf.Session() as s:
    tf.initialize_all_variables().run()
    total_steps = epochs*train_size//batch
    for step in range(total_steps):
        print("step "+str(step)+"/"+str(total_steps)+":",end="\t")
        offset = (step*batch)%train_size
        b_data = trn_x[offset: (offset+batch)]
        b_label = trn_y[offset: (offset+batch)]
        train_step.run(feed_dict={x:b_data, y:b_label})
        print("loss:"+loss.eval(feed_dict={x:b_data, y:b_label}))
Full error log:
ValueError Traceback (most recent call last)
<ipython-input-23-57d6ca3d1893> in <module>
7 b_data = trn_x[offset: (offset+batch)]
8 b_label = trn_y[offset: (offset+batch)]
----> 9 train_step.run(feed_dict={x:b_data, y:b_label})
10 print("loss:"+loss.eval(feed_dict={x:b_data, y:b_label}))
/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in run(self, feed_dict, session)
2677 none, the default session will be used.
2678 """
-> 2679 _run_using_default_session(self, feed_dict, self.graph, session)
2680
2681 _gradient_registry = registry.Registry("gradient")
/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _run_using_default_session(operation, feed_dict, graph, session)
5612 "the operation's graph is different from the session's "
5613 "graph.")
-> 5614 session.run(operation, feed_dict)
5615
5616
/usr/local/lib/python3.7/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
948 try:
949 result = self._run(None, fetches, feed_dict, options_ptr,
--> 950 run_metadata_ptr)
951 if run_metadata:
952 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/usr/local/lib/python3.7/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1147 'which has shape %r' %
1148 (np_val.shape, subfeed_t.name,
-> 1149 str(subfeed_t.get_shape())))
1150 if not self.graph.is_feedable(subfeed_t):
1151 raise ValueError('Tensor %s may not be fed.' % subfeed_t)
ValueError: Cannot feed value of shape (100,) for Tensor 'Placeholder_1:0', which has shape '(?, 1)'
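The message comes from a static shape check on the feed: the fed array must have the same rank as the placeholder, and every known dimension must match, while "?" (None) matches any size. In TF 1.x this check appears to be done with TensorShape.is_compatible_with inside Session._run (the frame at the bottom of the traceback). The function below is a hypothetical, TensorFlow-free sketch of that rule, written only to show why (100,) is rejected; it is not TensorFlow's code.

```python
def is_compatible(value_shape, placeholder_shape):
    """Hypothetical helper mimicking the TF 1.x feed-shape rule.

    None in placeholder_shape plays the role of '?' (any size).
    """
    # Ranks must agree: (100,) is rank 1, while (?, 1) is rank 2.
    if len(value_shape) != len(placeholder_shape):
        return False
    # Each known dimension must match exactly; None matches anything.
    return all(p is None or p == v for v, p in zip(value_shape, placeholder_shape))


print(is_compatible((100,), (None, 1)))    # False: rank 1 vs rank 2
print(is_compatible((100, 1), (None, 1)))  # True: batch size is free, width is 1
```

So the batch size of 100 is not the problem; the missing second axis is.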
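Solution
As the error says, the data you feed must match the shape of the placeholder. Your second placeholder, y, has shape [?, 1], but b_label has shape [100] only. Add this line after slicing b_label and before calling train_step.run: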
b_label = np.expand_dims(b_label,axis=-1)
This gives b_label the shape [100, 1], which matches your placeholder.
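A minimal sketch of the fix without TensorFlow, assuming trn_y is a 1-D pandas Series as returned by train_test_split. The label values and the random y_cap below are hypothetical stand-ins for the real dataset and for tf.matmul(x,w)+b. It shows the answer's np.expand_dims call, an equivalent alternative that reshapes the whole label column once, and why the column shape matters for the y*y_cap term in the hinge loss.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for trn_y: a 1-D Series of +1/-1 labels.
trn_y = pd.Series(np.where(np.arange(300) % 2 == 0, 1.0, -1.0))
batch = 100

# One batch, sliced by position (the question slices trn_y the same way).
b_label = trn_y.iloc[0:batch]
print(b_label.shape)  # (100,) - rank 1, rejected by the (?, 1) placeholder

# Option 1: the answer's fix; expand_dims returns an ndarray with a trailing axis.
b_label_col = np.expand_dims(b_label, axis=-1)
print(b_label_col.shape)  # (100, 1)

# Option 2: reshape the whole label column once, before the training loop.
trn_y_col = trn_y.to_numpy().reshape(-1, 1)
print(trn_y_col[0:batch].shape)  # (100, 1)

# Why (n, 1) matters: y_cap = x @ w + b has shape (batch, 1).
y_cap = np.random.rand(batch, 1)  # hypothetical predictions
print((b_label.to_numpy() * y_cap).shape)  # (100, 100): silent broadcasting
print((b_label_col * y_cap).shape)         # (100, 1): element-wise, as intended
```

Option 2 avoids converting each batch inside the loop; after it, slice trn_y_col instead of trn_y. The last two lines show that the (?, 1) placeholder is also what keeps y*y_cap element-wise: a (100,) label vector multiplied by a (100, 1) prediction broadcasts to a 100x100 matrix instead of raising an error.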