python - InvalidArgumentError:reshape 的输入是一个包含 0 个值(元素)的张量,但请求的形状需要 54912 个值
问题描述
非常初学者的问题,我希望没问题
我正在尝试使用 MAPS 数据集从 GitHub 训练此模型,并使用此代码为训练集创建了新的 .tfrecords。它基于此处的代码,但我更改了一些内容以便为不同的输入让路(另一个 MIDI 文件,我只是将其称为“tempo MIDI”)。
def create_train_set(tempopath, train_list, outdir, min_length, max_length):
    """Write ``train.tfrecord`` into ``outdir`` from wav / MIDI / tempo-MIDI files.

    For every entry of ``train_list`` the matching MIDI file and tempo MIDI
    file are looked up on disk, both are converted to note sequences, and
    every example yielded by ``aldu.process_record`` is serialized into the
    output TFRecord file.

    Args:
        tempopath: String prefix that is concatenated directly (plain ``+``)
            with the wav base name to locate ``<name>_tempo.mid`` or
            ``<name>_tempo.midi``; a directory must therefore already end
            with a path separator.
        train_list: Iterable of wav paths WITHOUT the ``.wav`` extension.
        outdir: Directory that receives ``train.tfrecord``.
        min_length: Forwarded unchanged to ``aldu.process_record``.
        max_length: Forwarded unchanged to ``aldu.process_record``.
    """
    train_file_pairs = []
    # Find the matching MIDI files. '.midi' is probed after '.mid', so when
    # both spellings exist on disk the '.midi' file is the one that is used.
    for wav_path in train_list:
        midi_file = ''
        tempo_midi_file = ''
        tempo_stem = tempopath + os.path.basename(wav_path) + '_tempo'
        for ext in ('.mid', '.midi'):
            if os.path.isfile(wav_path + ext):
                midi_file = wav_path + ext
            if os.path.isfile(tempo_stem + ext):
                tempo_midi_file = tempo_stem + ext
        wav_file = wav_path + '.wav'
        # NOTE(review): a missing MIDI file leaves '' in the tuple, which is
        # later handed to midi_file_to_note_sequence as-is.
        train_file_pairs.append((wav_file, midi_file, tempo_midi_file))

    train_output_name = os.path.join(outdir, 'train.tfrecord')
    with tf.python_io.TFRecordWriter(train_output_name) as writer:
        for idx, pair in enumerate(train_file_pairs):
            wav_file, midi_file, tempo_midi_file = pair
            print('{} of {}: {}'.format(idx, len(train_file_pairs), pair[0]))
            # Load the wav data; the context manager closes the handle
            # instead of leaving it open until garbage collection.
            with tf.gfile.Open(wav_file, 'rb') as wav_handle:
                wav_data = wav_handle.read()
            # Load both MIDI files and convert them to note sequences.
            ns = midi_io.midi_file_to_note_sequence(midi_file)
            tempo = midi_io.midi_file_to_note_sequence(tempo_midi_file)
            # aldu = audio_label_data_utils.py
            # NOTE(review): `sample_rate` is not defined in this function;
            # presumably a module-level value - verify it is in scope.
            for example in aldu.process_record(
                    wav_data, ns, tempo, wav_file, min_length, max_length,
                    sample_rate):
                writer.write(example.SerializeToString())
使用 tf.Example 如下:
# Build one tf.train.Example holding the id, the serialized note sequence,
# the raw wav bytes, the serialized tempo sequence and the velocity range.
# Every feature is stored as a single-element BytesList.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'id':
                tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[example_id.encode('utf-8')])),
            'sequence':
                tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[ns.SerializeToString()])),
            'audio':
                tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[wav_data])),
            # Fixed: this entry used to serialize `velocity_range` (copy-paste
            # of the entry below), so the tempo data never reached the record.
            # Assumes the tempo sequence is bound to `tempo` in this scope,
            # matching the argument passed to process_record - TODO confirm.
            'tempo':
                tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[tempo.SerializeToString()])),
            'velocity_range':
                tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[velocity_range.SerializeToString()])),
        }))
但是,当我尝试训练模型时,我收到此错误消息(我用打印行标记了 py 脚本,所以我知道一切都在哪里):
Running wav_to_spec from data.py
Running _wav_to_mel in data.py
Running wav_to_num_frames from data.py
Running wav_to_spec from data.py
Running _wav_to_mel in data.py
Running wav_to_num_frames from data.py
E0611 07:56:55.419340 8436 error_handling.py:70] Error recorded from training_loop: Input to reshape is a tensor with 0 values, but the requested shape has 54912
[[{{node Reshape_8}}]]
[[IteratorGetNext]]
I0611 07:56:55.420338 8436 error_handling.py:96] training_loop marked as finished
W0611 07:56:55.421335 8436 error_handling.py:130] Reraising captured error
Traceback (most recent call last):
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
return fn(*args)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912
[[{{node Reshape_8}}]]
[[IteratorGetNext]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "onsets_frames_transcription_train.py", line 128, in <module>
console_entry_point()
File "onsets_frames_transcription_train.py", line 124, in console_entry_point
tf.app.run(main)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\platform\app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\absl\app.py", line 300, in run
_run_main(main, args)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\absl\app.py", line 251, in _run_main
sys.exit(main(argv))
File "onsets_frames_transcription_train.py", line 120, in main
additional_trial_info=additional_trial_info)
File "onsets_frames_transcription_train.py", line 95, in run
num_steps=FLAGS.num_steps)
File "C:\Users\User\magenta\magenta\models\onsets_frames_transcription\train_util.py", line 134, in train
estimator.train(input_fn=transcription_data, max_steps=num_steps)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\tpu_estimator.py", line 2876, in train
rendezvous.raise_errors()
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\error_handling.py", line 131, in raise_errors
six.reraise(typ, value, traceback)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\six.py", line 693, in reraise
raise value
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\tpu_estimator.py", line 2871, in train
saving_listeners=saving_listeners)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 367, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1158, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1192, in _train_model_default
saving_listeners)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1484, in _train_with_estimator_spec
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 754, in run
run_metadata=run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1252, in run
run_metadata=run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1353, in run
raise six.reraise(*original_exc_info)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\six.py", line 693, in reraise
raise value
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1338, in run
return self._sess.run(*args, **kwargs)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1411, in run
run_metadata=run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1169, in run
return self._sess.run(*args, **kwargs)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
run_metadata_ptr)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912
[[{{node Reshape_8}}]]
[[IteratorGetNext]]
由此,我认为问题出在 wav_to_num_frames 上,但这个函数只有下面这几行代码:
def wav_to_num_frames(wav_audio, frames_per_second):
    """Convert wav-encoded audio bytes into a frame count.

    The wav header is read from an in-memory buffer; the number of audio
    frames divided by the frame rate is scaled by ``frames_per_second`` and
    the result is returned as ``np.int32``.
    """
    print("Running wav_to_num_frames from data")
    reader = wave.open(six.BytesIO(wav_audio))
    total_frames = reader.getnframes()
    frame_rate = reader.getframerate()
    scaled = total_frames / frame_rate * frames_per_second
    return np.int32(scaled)
当我使用原始代码创建的 tfrecords 训练模型时,并没有出现这个问题,所以我不知道哪里出了问题。
解决方案
事实证明,问题不在于创建的 .tfrecords 本身,而在于我为新添加的数据所分配的张量大小。不过,这里无法给出通用的答案,因为具体的修改取决于各自的情况。
推荐阅读
- python - rtsp 方法设置失败:461 客户端错误
- html - 因此
和其他元素有自己的填充
- php - 获取最多 n 级的 XML 标记和属性
- kotlin - 我如何在 kotlin 中与超类建立一对一的关系
- javascript - 如何替换包含不固定字符的字符串中的某些字符
- c# - XNA/Monogame 中的 Ray 视线
- json - 在 Swift 4 中访问嵌套 JSON 字符串的成员
- ios - 从 Openweathermap IOS 4 Xcode 10 获取 json 数据后 UI 不会更新
- c# - Json.NET 从自动属性初始化器获取默认值
- java - 更改树根后获得高度的最快方法是什么?