python - Keras 没有验证数据集
问题描述
我正在尝试按照语音识别权重和偏差的教程进行操作:
https://github.com/lukas/ml-class/tree/master/videos/cnn-audio
我做了教程中的所有操作,但收到以下错误消息:
wandb: WARNING No validation_data set, pass a generator to the callback.
您也可以通过 GitHub 链接查找代码,它非常相似(我只更改了标签的名称)。
预处理.py:
import librosa
import os
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import numpy as np
from tqdm import tqdm
DATA_PATH = "./data/"
# Input: Folder Path
# Output: Tuple (Label, Indices of the labels, one-hot encoded labels)
def get_labels(path=DATA_PATH):
    """Return (labels, label_indices, one_hot) for the sub-folders of *path*.

    Each sub-directory name under *path* is treated as one class label;
    the one-hot matrix row order follows ``os.listdir`` order.
    """
    labels = os.listdir(path)
    indices = np.arange(len(labels))
    return labels, indices, to_categorical(indices)
# convert file to wav2mfcc
# Mel-frequency cepstral coefficients
def wav2mfcc(file_path, n_mfcc=20, max_len=11):
    """Convert one audio file into a fixed-size MFCC matrix.

    Parameters
    ----------
    file_path : str
        Path to the audio file.
    n_mfcc : int
        Number of MFCC coefficients (rows of the result).
    max_len : int
        Number of time frames (columns); output is zero-padded or truncated.

    Returns
    -------
    np.ndarray of shape (n_mfcc, max_len).
    """
    wave, sr = librosa.load(file_path, mono=True, sr=None)
    # Crude 3x downsampling; asfortranarray gives librosa a contiguous buffer
    # instead of the strided view produced by the slice.
    wave = np.asfortranarray(wave[::3])
    # FIX: librosa >= 0.10 made the audio argument keyword-only; passing it
    # positionally raises TypeError.  Pass it explicitly as y=.
    # NOTE(review): sr=16000 is hard-coded even though the file was loaded with
    # sr=None and then decimated by 3 — presumably source files are 48 kHz;
    # confirm against the dataset.
    mfcc = librosa.feature.mfcc(y=wave, sr=16000, n_mfcc=n_mfcc)
    # If maximum length exceeds mfcc lengths then pad the remaining ones
    if max_len > mfcc.shape[1]:
        pad_width = max_len - mfcc.shape[1]
        mfcc = np.pad(mfcc, pad_width=((0, 0), (0, pad_width)), mode='constant')
    # Else cutoff the remaining parts
    else:
        mfcc = mfcc[:, :max_len]
    return mfcc
def save_data_to_array(path=DATA_PATH, max_len=11, n_mfcc=20):
    """Compute MFCCs for every wav in each label folder and save one .npy per label."""
    labels, _, _ = get_labels(path)
    for label in labels:
        folder = path + '/' + label
        wavfiles = [path + label + '/' + name for name in os.listdir(folder)]
        progress = tqdm(wavfiles, "Saving vectors of label - '{}'".format(label))
        mfcc_vectors = [wav2mfcc(w, max_len=max_len, n_mfcc=n_mfcc) for w in progress]
        # One .npy file per label, written into the current working directory.
        np.save(label + '.npy', mfcc_vectors)
def get_train_test(split_ratio=0.6, random_state=42):
    """Load the per-label .npy arrays and return a shuffled train/test split.

    Returns X_train, X_test, y_train, y_test (sklearn ``train_test_split`` order);
    *split_ratio* is the fraction kept for training.
    """
    labels, indices, _ = get_labels(DATA_PATH)
    # Seed X/y with the first label's data (class index 0) ...
    X = np.load(labels[0] + '.npy')
    y = np.zeros(X.shape[0])
    # ... then stack every remaining label, tagging rows with its class index.
    for class_idx, label in enumerate(labels[1:], start=1):
        arr = np.load(label + '.npy')
        X = np.vstack((X, arr))
        y = np.append(y, np.full(arr.shape[0], fill_value=class_idx))
    assert X.shape[0] == len(y)
    return train_test_split(X, y, test_size=(1 - split_ratio),
                            random_state=random_state, shuffle=True)
def prepare_dataset(path=DATA_PATH):
    """Build ``{label: {'path': [...], 'mfcc': [...]}}`` for every label folder.

    Unlike :func:`wav2mfcc`, the MFCC matrices here are NOT padded/truncated,
    so they may have varying numbers of frames.
    """
    labels, _, _ = get_labels(path)
    data = {}
    for label in labels:
        data[label] = {}
        data[label]['path'] = [path + label + '/' + wavfile
                               for wavfile in os.listdir(path + '/' + label)]
        vectors = []
        for wavfile in data[label]['path']:
            wave, sr = librosa.load(wavfile, mono=True, sr=None)
            # Downsampling by keeping every 3rd sample
            wave = wave[::3]
            # FIX: librosa >= 0.10 made the audio argument keyword-only; the
            # original positional call raises TypeError.  Same fix as wav2mfcc.
            mfcc = librosa.feature.mfcc(y=wave, sr=16000)
            vectors.append(mfcc)
        data[label]['mfcc'] = vectors
    return data
def load_dataset(path=DATA_PATH):
    """Flatten prepare_dataset() into ``[(label, mfcc), ...]``, capped at 100 items."""
    data = prepare_dataset(path)
    pairs = [(label, mfcc)
             for label in data
             for mfcc in data[label]['mfcc']]
    return pairs[:100]
# print(prepare_dataset(DATA_PATH))
音频.ipynb:
from preprocess import *
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D, LSTM
from tensorflow.keras.utils import to_categorical
import wandb
from wandb.keras import WandbCallback
import matplotlib.pyplot as plt

# Start ONE wandb run for the whole script.  (The original called wandb.init()
# a second time right before model.fit, which starts/resets a run mid-script.)
wandb.init()
config = wandb.config
config.max_len = 11
config.buckets = 20

# Save data to array file first
save_data_to_array(max_len=config.max_len, n_mfcc=config.buckets)

labels = ["off", "on", "stop"]

# Loading train set and test set
X_train, X_test, y_train, y_test = get_train_test()

# Feature dimension
channels = 1
config.epochs = 100
config.batch_size = 100
num_classes = 3

# 4-D (N, buckets, max_len, 1) view, used only to visualise one sample.
X_train = X_train.reshape(X_train.shape[0], config.buckets, config.max_len, channels)
X_test = X_test.reshape(X_test.shape[0], config.buckets, config.max_len, channels)
plt.imshow(X_train[100, :, :, 0])
print(y_train[100])

y_train_hot = to_categorical(y_train)
y_test_hot = to_categorical(y_test)

# Back to 3-D (N, buckets, max_len) for the Flatten-based model below.
X_train = X_train.reshape(X_train.shape[0], config.buckets, config.max_len)
X_test = X_test.reshape(X_test.shape[0], config.buckets, config.max_len)

model = Sequential()
model.add(Flatten(input_shape=(config.buckets, config.max_len)))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss="categorical_crossentropy",
              optimizer="adam",
              metrics=['accuracy'])

# FIX (the asked-about warning): recent Keras no longer exposes
# model.validation_data to callbacks, so WandbCallback cannot discover the
# validation set passed to model.fit and prints
# "WARNING No validation_data set, pass a generator to the callback."
# Hand the validation set to the callback explicitly.
# Also pass config.batch_size, which was declared but never used.
model.fit(X_train, y_train_hot,
          epochs=config.epochs,
          batch_size=config.batch_size,
          validation_data=(X_test, y_test_hot),
          callbacks=[WandbCallback(validation_data=(X_test, y_test_hot),
                                   data_type="image",
                                   labels=labels)])
我不理解为什么会出现这个错误：我在代码中明明通过 model.fit 的 validation_data 参数传入了验证数据，却找不到任何需要另外设置验证数据的地方。我也想使用我下载的验证数据集，但不知道该如何传入。
您能帮我解决验证数据的错误吗?
wandb 版本:0.10.30
解决方案
这个警告来自 wandb 的 Keras 回调：较新版本的 Keras 不再把传给 model.fit 的 validation_data 暴露给回调（callback 内部读不到 model.validation_data），所以 WandbCallback 找不到验证集并打印该警告。解决办法是把验证集直接传给回调本身，例如：
WandbCallback(validation_data=(X_test, y_test_hot), data_type="image", labels=labels)
（或升级 wandb 到更新的版本；该警告不影响训练本身，只影响 wandb 对验证样本的可视化记录。）
推荐阅读
- django-rest-framework - 错误 401:invalid_client 未找到 OAuth 客户端
- ios - iOS:BGProcessingTaskRequest 因 CPU 使用率高而被杀死
- swift - 如何更改导航栏和后退按钮颜色 iOS 15
- java - 如何使用 Jackson 创建不可变 DTO 但没有注释?
- javascript - 在 Ckeditor 中,如果我输入文本或退格输入复选框将取消选中并且值将更改
- c++ - Android NDK Crashanalytics 将日志发送到 Firebase 时出错。(C++)
- flutter - 如何在颤动中解决(操作系统错误:不允许操作,errno = 1)
- pandas - Databricks 火花雪花 dataframe.toPandas() 占用更多空间和时间
- python - Serializers.validated_data 字段随 DRF 中的源值而更改
- c++ - IIS 在回收期间是否取消/注册 C++ COM 库?