python - Keras 验证生成器评估指标
问题描述
我正在尝试对一维信号数据集进行二进制分类。我正在使用keras 2.2.4
backed bytensorflow 1.12.0
来完成这项任务。我需要在我的验证数据上应用一个自定义指标,所以我创建了这个Metrics
类。这是我正在使用的主要代码:
train_gen = load.data_generator(batch_size, preproc, *train)
dev_gen = load.data_generator(batch_size, preproc, *dev)
valid_metrics = Metrics(dev_gen, len(dev[0]) // batch_size, batch_size)
model.fit_generator(train_gen,
steps_per_epoch=len(train[0]) // batch_size,
epochs=MAX_EPOCHS,
validation_data=dev_gen,
validation_steps=len(dev[0]) // batch_size,
callbacks=[valid_metrics, checkpointer, reduce_lr, tensorboard])
class Metrics(keras.callbacks.Callback):
def __init__(self, val_data, step, batch_size=20):
self.validation_data = val_data
self.batch_size = batch_size
self.validation_step = step
print('validation_step ' + str(step))
def on_train_begin(self, logs={}):
self._rocdata_perclass = []
self._prdata_perclass = []
self._accdata = []
def on_epoch_end(self, batch, logs={}):
for batch_index in range(self.validation_step):
xVal, yVal = next(self.validation_data)
predictions = self.model.predict(xVal)
print(predictions.shape)
print(yVal.shape)
****** Training output*****
1/12 [=>............................] - ETA: 2:27 - loss: 1.1788 - acc: 0.5113
x shape (16, 38400, 1)
y shape (16, 150, 2)
2/12 [====>.........................] - ETA: 1:11 - loss: 2.1305 - acc: 0.4668
x shape (16, 38144, 1)
y shape (16, 149, 2)
3/12 [======>.......................] - ETA: 44s - loss: 2.0981 - acc: 0.4302
x shape (16, 38400, 1)
y shape (16, 150, 2)
4/12 [=========>....................] - ETA: 30s - loss: 1.8517 - acc: 0.3823
x shape (16, 38400, 1)
y shape (16, 150, 2)
5/12 [===========>..................] - ETA: 21s - loss: 1.6540 - acc: 0.4184
x shape (16, 38144, 1)
y shape (16, 149, 2)
6/12 [==============>...............] - ETA: 16s - loss: 1.5060 - acc: 0.4350
x shape (16, 38912, 1)
y shape (16, 152, 2)
7/12 [================>.............] - ETA: 12s - loss: 1.4447 - acc: 0.3977
x shape (16, 38400, 1)
y shape (16, 150, 2)
8/12 [===================>..........] - ETA: 8s - loss: 1.3394 - acc: 0.4423
x shape (16, 41472, 1)
y shape (16, 162, 2)
9/12 [=====================>........] - ETA: 6s - loss: 1.2576 - acc: 0.4718
x shape (16, 37888, 1)
y shape (16, 148, 2)
10/12 [========================>.....] - ETA: 3s - loss: 1.1878 - acc: 0.5138
x shape (16, 39424, 1)
y shape (16, 154, 2)
11/12 [==========================>...] - ETA: 1s - loss: 1.1355 - acc: 0.5395
x shape (16, 39168, 1)
y shape (16, 153, 2)
正如您在上面的训练输出中看到的那样,每个批次输入 (x) 和输出 (y) 的形状彼此不同,因为我将数据填充为具有相同的长度。一个信号有n classes of dimension 2
。我的问题与我需要使用Metrics
该类对验证数据进行的评估有关。由于我对每个批次的验证也有不同的输入和输出形状,我如何使用 ROC 曲线下的面积一次评估整个验证数据?正如您在上面的代码中看到的,我只能使用批量验证,而我需要整个验证数据,以便能够对其进行一次评估。
******* Validation Metric output ********
x shape (16, 41984, 1)
y shape (16, 164, 2)
x shape (16, 38400, 1)
y shape (16, 150, 2)
提前致谢。
解决方案
推荐阅读
- r - R:编写带有可选参数的函数
- php - php 动态 POST[] 值
- laravel - 只有一个值是 get 方法中使用explode 变量
- sh - 通过 root 为 oracle 用户执行时脚本不提示输入变量
- sql-server - 如何检查表中是否存在特定记录
- python - Python Paramiko,PermissionError:[Errno 13] 从远程服务器获取文件时权限被拒绝
- db2 - Db2 z/OS 替代 Db2 LUW 的 ROUTINE_SCHEMA 或 ROUTINE_SPECIFIC_NAME 全局变量
- angular - 串联组合请求 - Angular 7
- amazon-web-services - AWS API Gateway “x-amzn-requestid” - 可以在 .net 核心中捕获吗
- apache - Cookie 规范设置为过时的 RFC2109spec RFC2965spec 和 apache HttpClient4.4.4 中的 netscape 规范