首页 > 解决方案 > 关于使用 tf.records 文件训练 tensorflow 估计器的问题

问题描述

我正在使用tensorflow estimator文件tf.records进行培训,但我无法解决一些错误。

我的数据集很大......所以我必须改变我的数据集npz->tfrecords。TF 估计器在 numpy 文件数据集上运行良好。

首先,我保存的 tfrecords 文件 shape = (N, 437, 256)。我的 batch_size 是 30。这是我的代码。


def parser(serialized_example):
    """Parses a single tf.Example into image and label tensors."""
    features = tf.parse_single_example(
            serialized_example,
            features={
                'train_input_enc/time': tf.FixedLenFeature([], tf.float32),
                'train_input_enc/fft': tf.FixedLenFeature([], tf.float32),
                'train_input_enc/data_raw': tf.FixedLenFeature([], tf.string),
                'train_output_dec/time': tf.FixedLenFeature([], tf.float32),
                'train_output_dec/fft': tf.FixedLenFeature([], tf.float32),
                'train_output_dec/data_raw': tf.FixedLenFeature([], tf.string),
                'train_target_dec/time': tf.FixedLenFeature([], tf.float32),
                'train_target_dec/fft': tf.FixedLenFeature([], tf.float32),
                'train_target_dec/data_raw': tf.FixedLenFeature([], tf.string)
                })
    inp_enc_spec = tf.decode_raw(features['train_input_enc/data_raw'], tf.float32)
    oup_dec_spec = tf.decode_raw(features['train_output_dec/data_raw'], tf.float32)
    trg_dec_spec = tf.decode_raw(features['train_target_dec/data_raw'], tf.float32)

    inp_enc_spectrogram = tf.reshape(inp_enc_spec, [DEFINES.max_sequence_length, DEFINES.embedding_size])
    oup_dec_spectrogram = tf.reshape(oup_dec_spec, [DEFINES.max_sequence_length, DEFINES.embedding_size])
    trg_dec_spectrogram = tf.reshape(trg_dec_spec, [DEFINES.max_sequence_length, DEFINES.embedding_size])

    return inp_enc_spectrogram, oup_dec_spectrogram, trg_dec_spectrogram


def rearrange(input, output, target):
    features = {"input": input, "output": output}
    return features, target

def train_input_fn(train_input_enc, train_output_dec, train_target_dec, batch_size):

    dataset = tf.data.TFRecordDataset([train_input_enc, train_output_dec, train_target_dec])
    print("input dataset", dataset)

    dataset = dataset.map(map_func=parser)
    print("dataset map with parser",dataset) # <TFRecordDataset shapes: (), types: tf.string>
    dataset = dataset.shuffle(buffer_size=900)
    print("dataset shuffle", dataset) #  <MapDataset shapes: ((437, 256), (437, 256), (437, 256)), types: (tf.float32, tf.float32, tf.float32)>
    dataset = dataset.batch(batch_size = batch_size, drop_remainder=True)
    print("dataset batch", dataset) # <BatchDataset shapes: ((30, 437, 256), (30, 437, 256), (30, 437, 256)), types: (tf.float32, tf.float32, tf.float32)>
    dataset = dataset.map(rearrange)
    print("dataset after rearrange",dataset) # <MapDataset shapes: ({output: (30, 437, 256), input: (30, 437, 256)}, (30, 437, 256)), types: ({output: tf.float32, input: tf.float32}, tf.float32)>
    dataset = dataset.repeat()
    print("dataset repeat",dataset) # <RepeatDataset shapes: ({output: (30, 437, 256), input: (30, 437, 256)}, (30, 437, 256)), types: ({output: tf.float32, input: tf.float32}, tf.float32)>
    iterator = dataset.make_initializable_iterator()
    print("iterator init", iterator) # <tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7f9ed1700128>
#    iterator = dataset.make_one_shot_iterator()
#    print("iterator oneshow",iterator) 

    iterator = iterator.get_next()
    print("iterator",iterator) # ({'output': <tf.Tensor 'IteratorGetNext:1' shape=(30, 437, 256) dtype=float32>, 'input': <tf.Tensor 'IteratorGetNext:0' shape=(30, 437, 256) dtype=float32>}, <t
f.Tensor 'IteratorGetNext:2' shape=(30, 437, 256) dtype=float32>)

    return iterator

def main()

    filename_inp_enc = './train_input_enc.tfrecords'
    filename_oup_dec = './train_output_dec.tfrecords'
    filename_trg_dec = './train_target_dec.tfrecords'

    check_point_path = os.path.join(os.getcwd(), DEFINES.check_point_path)
    os.makedirs(check_point_path, exist_ok=True)

    classifier = tf.estimator.Estimator(
        model_fn=ml.Model,
        model_dir=DEFINES.check_point_path,
        config=estimator_config,
        params={
            ...
        })

    classifier.train(input_fn=lambda: data.train_input_fn(filename_inp_enc, filename_oup_dec, filename_trg_dec, DEFINES.batch_size),steps=DEFINES.train_steps)



错误消息是:

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
         [[{{node IteratorGetNext}} = IteratorGetNext[output_shapes=[[30,437,256], [30,437,256], [30,437,256]], output_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorV2)]]

当我把

iterator = dataset.make_one_shot_iterator()
return iterator.get_next()

错误信息如下:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Feature: train_output_dec/data_raw (data type: string) is required but could not be found.
         [[{{node ParseSingleExample/ParseSingleExample}} = ParseSingleExample[Tdense=[DT_STRING, DT_FLOAT, DT_FLOAT, DT_STRING, DT_FLOAT, DT_FLOAT, DT_STRING, DT_FLOAT, DT_FLOAT], dense_keys=["train_input_enc/data_raw", "train_input_enc/fft", "train_input_enc/time", "train_output_dec/data_raw", "train_output_dec/fft", "train_output_dec/time", "train_target_dec/data_raw", "train_target_dec/fft", "train_target_dec/time"], dense_shapes=[[], [], [], [], [], [], [], [], []], num_sparse=0, sparse_keys=[], sparse_types=[], _device="/device:CPU:0"](arg0, ParseSingleExample/Const, ParseSingleExample/Const_1, ParseSingleExample/Const_1, ParseSingleExample/Const, ParseSingleExample/Const_1, ParseSingleExample/Const_1, ParseSingleExample/Const, ParseSingleExample/Const_1, ParseSingleExample/Const_1)]]
         [[{{node IteratorGetNext}} = IteratorGetNext[output_shapes=[[30,437,256], [30,437,256], [30,437,256]], output_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
         [[{{node IteratorGetNext/_4351}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_192_IteratorGetNext", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

我怎么解决这个问题?

标签: iteratortensorflow-datasetstensorflow-estimatortfrecord

解决方案


推荐阅读