python - 当我尝试将图像数据和标签插入到我的预训练张量流模型(shuffleNet)中时出现关键错误
问题描述
我正在尝试恢复一个名为 sufflenet 的预训练模型,并且它已经被其他人训练过。但是,当我尝试检索模型并恢复 tensorflow 图以训练新数据集(10000 个图像数据集)时,我在终端中收到一条关键错误消息:这是我的代码:
meta_path = './model/model.ckpt-0.meta'
tf.reset_default_graph()
saver = tf.train.import_meta_graph(meta_path)
restored_graph = tf.get_default_graph()
for tensor in restored_graph.get_operations():
print (tensor.name)
global_step_tensor = restored_graph.get_tensor_by_name('Softmax/prediction:0')
image_input_node = restored_graph.get_tensor_by_name('TFRecordIterator/IteratorGetNext:0')
label_node = restored_graph.get_tensor_by_name('TFRecordIterator/OneShotIterator:0')
loss = restored_graph.get_tensor_by_name('sparse_softmax_cross_entropy/add:0')
# tf.contrib.quantize.create_training_graph(input_graph=restored_graph, quant_delay=2000000)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
# tf.contrib.quantize.create_training_graph(quant_delay=2000000)
iterations = 100
# run the session
with tf.Session() as sess:
# restore the saved vairable
saver.restore(sess, './model/model.ckpt-0')
# sess.run(optimizer)
image_files = []
labels = []
with open("/mnt/ficussweden/hhzhang/00_Data/04_TF_Autobot/gender/data/train.tsv", 'r') as f:
line = f.readline()
while line:
array = line.rstrip('\n').split()
image_files.append(array[0])
labels.append(int(array[1]))
line = f.readline()
if len(image_files)>=10001:
break
# print(image_files[len(image_files)-1])
data = []
data_labels = []
for i in range(200):
print("process batch {}-th data \n".format(i))
batch = []
batch_label = []
for j in range(50):
img_path = image_files[j+i*50]
img = np.array(Image.open(img_path))
batch.append(img)
batch_label.append(labels[j+i*50])
batch = np.array(batch)
batch_label = np.array(batch_label)
data.append(batch)
data_labels.append(batch_label)
data = np.array(data)
data_labels = np.array(data_labels)
print(data.shape)
print(data_labels.shape)
for i in range(iterations):
for j in range(len(data)):
batch_data = data[j]
batch_label = data_labels[j]
res = sess.run(train_op, feed_dict = {image_input_node: batch_data, label_node: batch_label})
print(res)
请注意,我将所有图像数据读入一个 numpy 列表,然后根据输入节点将它们放入我的还原图中。但我收到以下错误消息:
Traceback (most recent call last):
File "restore_graph_train.py", line 127, in <module>
res = sess.run(train_op, feed_dict = {image_input_node: batch_data, label_node: batch_label})
File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1088, in _run
subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py", line 128, in as_numpy_dtype
return _TF_TO_NP[self._type_enum]
KeyError: 20
解决方案
推荐阅读
- python - 在python中将json数据转换为pandas数据框(列表内的字典)
- google-sheets - 谷歌表格数组公式列结果乘以单元格值
- python - 在 pandas 列中替换为 Python 正则表达式
- c# - 持久字典
无法为自定义结构创建 - html - css 图像/图像:后和背景颜色
- python - 如何检查python中的路径是否是文件?
- ios - Alamofire 堆叠响应
- docker - 在 docker 容器中使用 nginx certbot 时证书无效
- python - 不支持 img 数据类型 = 17
- python - 用音频录制桌面。Python ffmpeg