KeyError when I try to feed image data and labels into my pretrained TensorFlow model (ShuffleNet)

Problem description

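I am trying to restore a pretrained model called ShuffleNet, which was trained by someone else. However, when I load the model and restore the TensorFlow graph to train it on a new dataset (10,000 images), I get a KeyError in the terminal. Here is my code: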

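import numpy as np
import tensorflow as tf
from PIL import Image

meta_path = './model/model.ckpt-0.meta'
tf.reset_default_graph()
# rebuild the pretrained graph from its .meta file
saver = tf.train.import_meta_graph(meta_path)
restored_graph = tf.get_default_graph()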

# print every operation name to locate the input, label and loss nodes
for op in restored_graph.get_operations():
    print(op.name)


prediction_node = restored_graph.get_tensor_by_name('Softmax/prediction:0')  # model output, not a global step
image_input_node = restored_graph.get_tensor_by_name('TFRecordIterator/IteratorGetNext:0')
label_node = restored_graph.get_tensor_by_name('TFRecordIterator/OneShotIterator:0')
loss = restored_graph.get_tensor_by_name('sparse_softmax_cross_entropy/add:0')

# tf.contrib.quantize.create_training_graph(input_graph=restored_graph, quant_delay=2000000)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
# tf.contrib.quantize.create_training_graph(quant_delay=2000000)


iterations = 100

# run the session
with tf.Session() as sess:
    # restore the saved variables
    saver.restore(sess, './model/model.ckpt-0')
    # sess.run(optimizer)

    image_files = []
    labels = []

    with open("/mnt/ficussweden/hhzhang/00_Data/04_TF_Autobot/gender/data/train.tsv", 'r') as f:
        for line in f:
            array = line.rstrip('\n').split()
            image_files.append(array[0])
            labels.append(int(array[1]))

            # 200 batches x 50 images = 10000 samples are used below
            if len(image_files) >= 10000:
                break

            # print(image_files[-1])

    data = []
    data_labels = []
    for i in range(200):
        print("process batch {}-th data \n".format(i))
        batch = []
        batch_label = []
        for j in range(50):
            img_path = image_files[j+i*50]
            img = np.array(Image.open(img_path))
            batch.append(img)
            batch_label.append(labels[j+i*50])

        batch = np.array(batch)
        batch_label = np.array(batch_label)

        data.append(batch)
        data_labels.append(batch_label)
    data = np.array(data)
    data_labels = np.array(data_labels)
    print(data.shape)
    print(data_labels.shape)

    for i in range(iterations):
        for j in range(len(data)):
            batch_data = data[j]
            batch_label = data_labels[j]


            res = sess.run(train_op, feed_dict = {image_input_node: batch_data, label_node: batch_label})
            print(res)
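The data-loading part of the script above builds one 200 x 50 array holding all 10,000 images before training starts. A variant that keeps only the (path, label) list in memory and decodes images one batch at a time is sketched below. It is a minimal, stdlib-only sketch: `read_tsv` and `iter_batches` are hypothetical helper names, and it assumes each line of train.tsv is `<image path><whitespace><integer label>`, as the original `split()` call implies.

```python
import io
from typing import Iterable, Iterator, List, Tuple


def read_tsv(lines: Iterable[str], limit: int) -> List[Tuple[str, int]]:
    """Parse '<path> <label>' lines into (path, label) pairs, stopping after `limit`."""
    samples = []
    for line in lines:
        fields = line.split()
        if not fields:  # skip blank lines instead of raising IndexError
            continue
        samples.append((fields[0], int(fields[1])))
        if len(samples) >= limit:
            break
    return samples


def iter_batches(samples: List[Tuple[str, int]],
                 batch_size: int) -> Iterator[Tuple[List[str], List[int]]]:
    """Yield consecutive full batches of (paths, labels); a trailing partial batch is dropped."""
    for start in range(0, len(samples) - batch_size + 1, batch_size):
        chunk = samples[start:start + batch_size]
        yield [path for path, _ in chunk], [label for _, label in chunk]


if __name__ == "__main__":
    # Hypothetical in-memory stand-in for train.tsv (the blank line is skipped).
    fake_tsv = io.StringIO("a.jpg\t0\nb.jpg\t1\n\nc.jpg\t1\nd.jpg\t0\ne.jpg\t1\n")
    samples = read_tsv(fake_tsv, limit=10000)
    for paths, labels in iter_batches(samples, batch_size=2):
        print(paths, labels)
```

Inside the training loop, each `paths` list would then be decoded with `np.array([np.array(Image.open(p)) for p in paths])`, which assumes all images share the same height, width and channel count.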

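Note that I read all of the image data into NumPy arrays and then feed them into the restored graph through its input nodes. However, I get the following error message: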

Traceback (most recent call last):
  File "restore_graph_train.py", line 127, in <module>
    res = sess.run(train_op, feed_dict = {image_input_node: batch_data, label_node: batch_label})
  File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1088, in _run
    subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
  File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py", line 128, in as_numpy_dtype
    return _TF_TO_NP[self._type_enum]
KeyError: 20
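One likely reading of this traceback: `Session.run` looks up a NumPy dtype for every `feed_dict` key, and `20` is the value of `DT_RESOURCE` in TensorFlow's `DataType` enum (`tensorflow/core/framework/types.proto`). `TFRecordIterator/OneShotIterator:0` is the output of a `OneShotIterator` op, which is an iterator resource handle rather than a label tensor, and the TensorFlow version in the traceback has no NumPy mapping for that dtype. The stdlib-only sketch below decodes the number and flags such feed keys; the enum subset is copied by hand (an assumption, not imported from TensorFlow), and the image dtype in the example is assumed. With TensorFlow available, the equivalent check would be `tensor.dtype == tf.resource`. If the dataset yields `(image, label)` pairs (also an assumption), the labels would more plausibly be the second output of the iterator, `TFRecordIterator/IteratorGetNext:1`.

```python
# Subset of TensorFlow's DataType enum, copied by hand from
# tensorflow/core/framework/types.proto (assumption: not imported from TF).
DT_ENUM = {
    1: "DT_FLOAT",
    3: "DT_INT32",
    4: "DT_UINT8",
    9: "DT_INT64",
    20: "DT_RESOURCE",
}

# Handle dtype with no NumPy mapping in the TF version from the traceback,
# so a tensor of this dtype cannot be used as a feed_dict key there.
NOT_FEEDABLE = {"DT_RESOURCE"}


def dtype_name(enum_value: int) -> str:
    """Translate the number printed in 'KeyError: <n>' into a DataType name."""
    return DT_ENUM.get(enum_value, "unknown enum %d" % enum_value)


def unfeedable(feeds: dict) -> list:
    """Return the tensor names in `feeds` (name -> dtype enum) that cannot be fed."""
    return sorted(name for name, enum in feeds.items()
                  if dtype_name(enum) in NOT_FEEDABLE)


if __name__ == "__main__":
    print(dtype_name(20))
    # OneShotIterator outputs a resource handle; DT_FLOAT for the image
    # component is only an illustrative assumption.
    feeds = {
        "TFRecordIterator/IteratorGetNext:0": 1,
        "TFRecordIterator/OneShotIterator:0": 20,
    }
    print(unfeedable(feeds))
```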

Tags: python, tensorflow, neural-network
