iterator - 关于使用 tf.records 文件训练 tensorflow 估计器的问题
问题描述
我正在使用tensorflow estimator
文件tf.records
进行培训,但我无法解决一些错误。
我的数据集很大......所以我必须改变我的数据集npz->tfrecords
。TF 估计器在 numpy 文件数据集上运行良好。
首先,我保存的 tfrecords 文件 shape = (N, 437, 256)。我的 batch_size 是 30。这是我的代码。
def parser(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
features = tf.parse_single_example(
serialized_example,
features={
'train_input_enc/time': tf.FixedLenFeature([], tf.float32),
'train_input_enc/fft': tf.FixedLenFeature([], tf.float32),
'train_input_enc/data_raw': tf.FixedLenFeature([], tf.string),
'train_output_dec/time': tf.FixedLenFeature([], tf.float32),
'train_output_dec/fft': tf.FixedLenFeature([], tf.float32),
'train_output_dec/data_raw': tf.FixedLenFeature([], tf.string),
'train_target_dec/time': tf.FixedLenFeature([], tf.float32),
'train_target_dec/fft': tf.FixedLenFeature([], tf.float32),
'train_target_dec/data_raw': tf.FixedLenFeature([], tf.string)
})
inp_enc_spec = tf.decode_raw(features['train_input_enc/data_raw'], tf.float32)
oup_dec_spec = tf.decode_raw(features['train_output_dec/data_raw'], tf.float32)
trg_dec_spec = tf.decode_raw(features['train_target_dec/data_raw'], tf.float32)
inp_enc_spectrogram = tf.reshape(inp_enc_spec, [DEFINES.max_sequence_length, DEFINES.embedding_size])
oup_dec_spectrogram = tf.reshape(oup_dec_spec, [DEFINES.max_sequence_length, DEFINES.embedding_size])
trg_dec_spectrogram = tf.reshape(trg_dec_spec, [DEFINES.max_sequence_length, DEFINES.embedding_size])
return inp_enc_spectrogram, oup_dec_spectrogram, trg_dec_spectrogram
def rearrange(input, output, target):
features = {"input": input, "output": output}
return features, target
def train_input_fn(train_input_enc, train_output_dec, train_target_dec, batch_size):
dataset = tf.data.TFRecordDataset([train_input_enc, train_output_dec, train_target_dec])
print("input dataset", dataset)
dataset = dataset.map(map_func=parser)
print("dataset map with parser",dataset) # <TFRecordDataset shapes: (), types: tf.string>
dataset = dataset.shuffle(buffer_size=900)
print("dataset shuffle", dataset) # <MapDataset shapes: ((437, 256), (437, 256), (437, 256)), types: (tf.float32, tf.float32, tf.float32)>
dataset = dataset.batch(batch_size = batch_size, drop_remainder=True)
print("dataset batch", dataset) # <BatchDataset shapes: ((30, 437, 256), (30, 437, 256), (30, 437, 256)), types: (tf.float32, tf.float32, tf.float32)>
dataset = dataset.map(rearrange)
print("dataset after rearrange",dataset) # <MapDataset shapes: ({output: (30, 437, 256), input: (30, 437, 256)}, (30, 437, 256)), types: ({output: tf.float32, input: tf.float32}, tf.float32)>
dataset = dataset.repeat()
print("dataset repeat",dataset) # <RepeatDataset shapes: ({output: (30, 437, 256), input: (30, 437, 256)}, (30, 437, 256)), types: ({output: tf.float32, input: tf.float32}, tf.float32)>
iterator = dataset.make_initializable_iterator()
print("iterator init", iterator) # <tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7f9ed1700128>
# iterator = dataset.make_one_shot_iterator()
# print("iterator oneshow",iterator)
iterator = iterator.get_next()
print("iterator",iterator) # ({'output': <tf.Tensor 'IteratorGetNext:1' shape=(30, 437, 256) dtype=float32>, 'input': <tf.Tensor 'IteratorGetNext:0' shape=(30, 437, 256) dtype=float32>}, <t
f.Tensor 'IteratorGetNext:2' shape=(30, 437, 256) dtype=float32>)
return iterator
def main()
filename_inp_enc = './train_input_enc.tfrecords'
filename_oup_dec = './train_output_dec.tfrecords'
filename_trg_dec = './train_target_dec.tfrecords'
check_point_path = os.path.join(os.getcwd(), DEFINES.check_point_path)
os.makedirs(check_point_path, exist_ok=True)
classifier = tf.estimator.Estimator(
model_fn=ml.Model,
model_dir=DEFINES.check_point_path,
config=estimator_config,
params={
...
})
classifier.train(input_fn=lambda: data.train_input_fn(filename_inp_enc, filename_oup_dec, filename_trg_dec, DEFINES.batch_size),steps=DEFINES.train_steps)
错误消息是:
FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
[[{{node IteratorGetNext}} = IteratorGetNext[output_shapes=[[30,437,256], [30,437,256], [30,437,256]], output_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorV2)]]
当我把
iterator = dataset.make_one_shot_iterator()
return iterator.get_next()
错误信息如下:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Feature: train_output_dec/data_raw (data type: string) is required but could not be found.
[[{{node ParseSingleExample/ParseSingleExample}} = ParseSingleExample[Tdense=[DT_STRING, DT_FLOAT, DT_FLOAT, DT_STRING, DT_FLOAT, DT_FLOAT, DT_STRING, DT_FLOAT, DT_FLOAT], dense_keys=["train_input_enc/data_raw", "train_input_enc/fft", "train_input_enc/time", "train_output_dec/data_raw", "train_output_dec/fft", "train_output_dec/time", "train_target_dec/data_raw", "train_target_dec/fft", "train_target_dec/time"], dense_shapes=[[], [], [], [], [], [], [], [], []], num_sparse=0, sparse_keys=[], sparse_types=[], _device="/device:CPU:0"](arg0, ParseSingleExample/Const, ParseSingleExample/Const_1, ParseSingleExample/Const_1, ParseSingleExample/Const, ParseSingleExample/Const_1, ParseSingleExample/Const_1, ParseSingleExample/Const, ParseSingleExample/Const_1, ParseSingleExample/Const_1)]]
[[{{node IteratorGetNext}} = IteratorGetNext[output_shapes=[[30,437,256], [30,437,256], [30,437,256]], output_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
[[{{node IteratorGetNext/_4351}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_192_IteratorGetNext", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
我怎么解决这个问题?
解决方案
推荐阅读
- javascript - Flagging cells in grid (Minesweeper in plain JavaScript + p5.js framework)
- html - 粘性元素在 Flex 框中不起作用
- python - 朴素递归比记忆递归更快
- python - 如何在 Python 中使用 Selenium 从在滚动时添加 div 的网页中抓取数据?
- elasticsearch - Elasticsearch 无法在 Ubuntu 20.04 中启动服务
- python - 使用 GPU 的 Tensorflow 比预期的要慢
- multithreading - 主题:如何删除静态生命周期要求
- python - 如何解决“NotImplementedError”
- javascript - Vue ReferenceError:未定义 slugify
- scala - 有没有办法让外部模块添加可在内部模块中使用的隐式?