tensorflow - InvalidArgumentError: Data too short when trying to read string [[{{node DecodeWav}}]] [Op:IteratorGetNext]
Question
I have been following a TensorFlow implementation to learn audio classification, but I hit an error when I tried it on a different dataset.
Code:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import glob
import os
filenames = tf.io.gfile.glob("C:/Users/Natha/Downloads/voice_emotions/audio_speech_actors_01-24/*/*")
filenames = tf.random.shuffle(filenames)
print("Number of files: ", len(filenames))
print("Sample file: ", filenames[0])
train_files = filenames[:1200]
val_files = filenames[1200:1300]
test_files = filenames[1300:]
def decode_audio(audio_binary):
    # decode_wav returns (samples, channels); drop the channel axis for mono audio.
    audio, _ = tf.audio.decode_wav(audio_binary)
    return tf.squeeze(audio, axis=-1)

def get_label(file_path):
    # File names are dash-separated codes; take the fourth field from the end.
    parts = tf.strings.split(file_path, "-")
    return parts[-4]

def get_waveform_and_label(file_path):
    label = get_label(file_path)
    audio_binary = tf.io.read_file(file_path)
    waveform = decode_audio(audio_binary)
    return waveform, label
AUTOTUNE = tf.data.AUTOTUNE
files_ds = tf.data.Dataset.from_tensor_slices(train_files)
waveform_ds = files_ds.map(get_waveform_and_label, num_parallel_calls=AUTOTUNE)
rows = 3
cols = 3
n = rows * cols
fig, axes = plt.subplots(rows, cols, figsize=(10, 12))
for i, (audio, label) in enumerate(waveform_ds.take(n)):
    r = i // cols
    c = i % cols
    ax = axes[r][c]
    ax.plot(audio.numpy())
    ax.set_yticks(np.arange(-1.2, 1.2, 0.2))
    label = label.numpy().decode('utf-8')
    ax.set_title(label)
plt.show()
Error:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-108-8a7064627ebf> in <module>
3 n = rows * cols
4 fig, axes = plt.subplots(rows, cols, figsize=(10, 12))
----> 5 for i, (audio, label) in enumerate(waveform_ds.take(n)):
6 r = i // cols
7 c = i % cols
~\anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py in __next__(self)
759 def __next__(self):
760 try:
--> 761 return self._next_internal()
762 except errors.OutOfRangeError:
763 raise StopIteration
~\anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py in _next_internal(self)
742 # to communicate that there is no more data to iterate over.
743 with context.execution_mode(context.SYNC):
--> 744 ret = gen_dataset_ops.iterator_get_next(
745 self._iterator_resource,
746 output_types=self._flat_output_types,
~\anaconda3\lib\site-packages\tensorflow\python\ops\gen_dataset_ops.py in iterator_get_next(iterator, output_types, output_shapes, name)
2725 return _result
2726 except _core._NotOkStatusException as e:
-> 2727 _ops.raise_from_not_ok_status(e, name)
2728 except _core._FallbackException:
2729 pass
~\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in raise_from_not_ok_status(e, name)
6939 message = e.message + (" name: " + name if name is not None else "")
6940 # pylint: disable=protected-access
-> 6941 six.raise_from(core._status_to_exception(e.code, message), None)
6942 # pylint: enable=protected-access
6943
~\anaconda3\lib\site-packages\six.py in raise_from(value, from_value)
InvalidArgumentError: Data too short when trying to read string
[[{{node DecodeWav}}]] [Op:IteratorGetNext]
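I have made sure that every item in the dataset is the correct size, but the dataset may still be the cause. I have seen many other people run into this error, but none of their solutions helped.
Versions: TensorFlow 2.6.0. Dataset: https://www.kaggle.com/uwrfkaggler/ravdess-emotional-speech-audio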
Answer
Update: I was able to find a workaround. I used this code to separate the valid files from the corrupted ones.
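corrupted_files = []
valid_files = []
waveform_iter = iter(waveform_ds)
# waveform_ds was built from train_files, so positions line up with that list.
for i in range(len(train_files)):
    try:
        waveform_iter.get_next()  # raises if this file cannot be decoded
        valid_files.append(train_files[i])
    except tf.errors.InvalidArgumentError:
        corrupted_files.append(train_files[i])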
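To see why a particular file is rejected, it can help to look at its header outside TensorFlow. The following is only a minimal sketch using Python's standard `wave` module; the helper name `describe_wav` is hypothetical and not part of the original code. It assumes the path is a plain `str`, and it does not establish the cause of the error here. It only reports what the header says, which matters because `tf.audio.decode_wav` is documented as decoding 16-bit PCM WAV data and the `tf.squeeze(audio, axis=-1)` call above expects a single channel.

```python
import wave


def describe_wav(path):
    """Return basic header facts for a WAV file, or the parse error.

    `path` must be a str. This helper is illustrative only and is not
    part of the original question or answer.
    """
    try:
        with wave.open(path, "rb") as wav:
            return {
                "channels": wav.getnchannels(),
                "bits_per_sample": wav.getsampwidth() * 8,
                "sample_rate": wav.getframerate(),
                "frames": wav.getnframes(),
            }
    except (wave.Error, EOFError) as err:
        # wave.Error: not a RIFF/WAVE file or an unsupported format tag.
        # EOFError: the file ends before the header is complete.
        return {"error": repr(err)}
```

With the lists built above, each entry is a string tensor, so a call would look like `describe_wav(path.numpy().decode("utf-8"))` for each `path` in `corrupted_files`.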
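Unfortunately, most of my files turned out to be corrupted, so I am still hoping for a way to actually fix the problem rather than just work around it.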
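One thing that may be worth trying, under the untested assumption that the flagged files still contain valid PCM audio and are only rejected because of extra or trailing bytes around it, is to rewrite each file with a fresh header. The sketch below uses only the standard `wave` module; `rewrite_wav` is a hypothetical helper name. If audio bytes are genuinely missing, this cannot bring them back; it will simply write whatever frames could be read.

```python
import wave


def rewrite_wav(src_path, dst_path):
    """Copy the audio frames of src_path into a newly written WAV file.

    Illustrative helper, not from the original post. Both paths must be
    str. The output keeps the source's channel count, sample width and
    sample rate, and gets a header that matches the frames actually read.
    """
    with wave.open(src_path, "rb") as src:
        params = src.getparams()
        frames = src.readframes(src.getnframes())
    with wave.open(dst_path, "wb") as dst:
        dst.setparams(params)
        # writeframes patches the frame count in the header on close,
        # so a short read still produces a consistent file.
        dst.writeframes(frames)
```

Writing to a separate `dst_path` keeps the originals untouched, so the result can be checked with the decoding code above before anything is replaced.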