python - 如何使用预训练网络对新的音频文件进行预测?
问题描述
我需要一种用于声音分类的深度学习算法,它使用预先训练好的模型对新的音频文件进行预测。由于我是 Python 新手,我正在研究一个我找到的、效果很好的声音分类算法,但目前它输出的是数据集中每个音频文件的真实类别及其预测类别。而我需要的是让它预测新音频文件的类别,这些文件完全没有标签;我尝试修改代码,但未能让它按我想要的方式工作。
我目前使用的代码:
from tensorflow.keras.models import load_model
from clean import downsample_mono, envelope
from kapre.time_frequency import STFT, Magnitude, ApplyFilterbank, MagnitudeToDecibel
from sklearn.preprocessing import LabelEncoder
import numpy as np
from glob import glob
import argparse
import os
import pandas as pd
from tqdm import tqdm
def _split_into_windows(signal, step):
    """Cut ``signal`` into consecutive ``(step, 1)`` float32-compatible windows.

    The signal is sliced along its first axis in strides of ``step``.  When
    the final slice holds fewer than ``step`` samples it is copied into a
    zero-filled ``(step, 1)`` float32 array, i.e. right-padded with zeros.
    An empty signal yields an empty list.

    NOTE(review): assumes ``signal`` is 1-D mono audio -- confirm against
    ``downsample_mono``.
    """
    windows = []
    for start in range(0, signal.shape[0], step):
        window = signal[start:start + step].reshape(-1, 1)
        if window.shape[0] < step:
            padded = np.zeros(shape=(step, 1), dtype=np.float32)
            padded[:window.shape[0], :] = window
            window = padded
        windows.append(window)
    return windows


def make_prediction(args):
    """Predict a class for every ``.wav`` file found under ``args.src_dir``.

    Each file is resampled with ``downsample_mono``, masked with
    ``envelope``, split into windows of ``int(args.sr * args.dt)`` samples
    and fed to the model as one batch.  The per-window model outputs are
    averaged; the arg-max of that average is printed as the predicted class
    and the averaged vector is collected.  All collected vectors are saved
    with ``np.save`` to ``logs/<args.pred_fn>``.

    Files whose masked signal is empty are skipped with a message, so the
    saved array can have fewer rows than there are ``.wav`` files.

    Args:
        args: namespace providing ``model_fn``, ``src_dir``, ``sr``, ``dt``,
            ``threshold`` and ``pred_fn``.
    """
    model = load_model(args.model_fn,
                       custom_objects={'STFT': STFT,
                                       'Magnitude': Magnitude,
                                       'ApplyFilterbank': ApplyFilterbank,
                                       'MagnitudeToDecibel': MagnitudeToDecibel})

    # Match on the file extension: a plain substring test would also pick up
    # directories or files such as "a.wav.bak".
    wav_paths = glob('{}/**'.format(args.src_dir), recursive=True)
    wav_paths = sorted(x.replace(os.sep, '/') for x in wav_paths
                       if x.endswith('.wav'))

    # NOTE(review): class names are the sorted entries of src_dir; this
    # presumably has to match the class order used at training time -- confirm.
    classes = sorted(os.listdir(args.src_dir))

    # Window length in samples; identical for every file, so compute it once.
    step = int(args.sr * args.dt)
    results = []

    for wav_fn in tqdm(wav_paths, total=len(wav_paths)):
        rate, wav = downsample_mono(wav_fn, args.sr)
        mask, _ = envelope(wav, rate, threshold=args.threshold)
        batch = _split_into_windows(wav[mask], step)
        if not batch:
            # Nothing survived the mask: an empty batch cannot be averaged
            # into a meaningful prediction.
            print('Skipping {}: no audio left after envelope masking'.format(wav_fn))
            continue

        X_batch = np.array(batch, dtype=np.float32)
        y_pred = model.predict(X_batch)
        y_mean = np.mean(y_pred, axis=0)
        y_pred = np.argmax(y_mean)
        # The parent directory name is reported as the file's actual class.
        real_class = os.path.dirname(wav_fn).split('/')[-1]
        print('Actual class: {}, Predicted class: {}'.format(real_class, classes[y_pred]))
        results.append(y_mean)

    # np.save does not create missing parent directories.
    os.makedirs('logs', exist_ok=True)
    np.save(os.path.join('logs', args.pred_fn), np.array(results))
if __name__ == '__main__':
    # Command-line entry point: parse the prediction options and run
    # make_prediction on them.
    parser = argparse.ArgumentParser(description='Audio Classification Training')
    parser.add_argument('--model_fn', type=str, default='models/lstm.h5',
                        help='model file to make predictions')
    parser.add_argument('--pred_fn', type=str, default='y_pred',
                        help='fn to write predictions in logs dir')
    parser.add_argument('--src_dir', type=str, default='wavfiles',
                        help='directory containing wavfiles to predict')
    parser.add_argument('--dt', type=float, default=1.0,
                        help='time in seconds to sample audio')
    parser.add_argument('--sr', type=int, default=16000,
                        help='sample rate of clean audio')
    # Parsed as a number so a value given on the command line has the same
    # type as the integer default (it was previously parsed as a string).
    parser.add_argument('--threshold', type=int, default=20,
                        help='threshold magnitude for np.int16 dtype')
    # parse_known_args: unrecognised arguments are ignored instead of
    # aborting the run.
    args, _ = parser.parse_known_args()
    make_prediction(args)
解决方案
在代码部分
wav_paths = sorted([x.replace(os.sep, '/') for x in wav_paths if '.wav' in x])
classes = sorted(os.listdir(args.src_dir))
labels = [os.path.split(x)[0].split('/')[-1] for x in wav_paths]
le = LabelEncoder() # delete
y_true = le.fit_transform(labels) # delete
results = []
删除标有 `# delete` 的 2 行。
然后,将下面这一段
X_batch = np.array(batch, dtype=np.float32)
y_pred = model.predict(X_batch)
y_mean = np.mean(y_pred, axis=0)
y_pred = np.argmax(y_mean)
real_class = os.path.dirname(wav_fn).split('/')[-1]
print('Actual class: {}, Predicted class: {}'.format(real_class, classes[y_pred]))
results.append(y_mean)
np.save(os.path.join('logs', args.pred_fn), np.array(results))
改写如下
X_batch = np.array(batch, dtype=np.float32)
y_pred = model.predict(X_batch)
y_mean = np.mean(y_pred, axis=0)
y_pred = np.argmax(y_mean)
print('Predicted class: {}'.format(classes[y_pred]))
results.append(y_mean)
# np.save(os.path.join('logs', args.pred_fn), np.array(results))
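这样应该就可以了。注意:`classes` 仍然取自 `os.listdir(args.src_dir)`,因此只有当 `src_dir` 下的条目与训练时的类别(及其顺序)一致时,打印出的类别名才是正确的;如果新的音频文件没有按类别分目录存放,需要把 `classes` 改为训练时使用的类别列表。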
推荐阅读
- node.js - Mongoose:如果数组包含某个值,则查找文档
- python - Python3 -- 256 AES 解密和加密 + 蛮力是我试图实现的目标。为什么我会遇到错误?
- pyspark - 如何将 SnowflakeCursor 转换为 pySpark 数据框
- python - Pandas DataFrame 上的组特定计算
- mocha.js - Mocha 的 reporter 不支持并行模式,有办法解决吗?
- python - 设置对象没有属性 DATABASE
- docker - 如何在浏览器中查看运行在容器上的 Web 应用服务器的结果
- google-apps-script - 有没有办法使用谷歌应用脚本的 GMAIL API 来更新 gmail 的签名“在回复/转发使用时”?
- c# - 是否可以将两个数组作为单个命令相乘以提高代码性能?
- javascript - 为什么按照示例教程解析错误意外令牌<出现在PHP Javascript的以下代码行中?