python - 我的 RNN 是否仅在 1 个或 2 个样本上进行训练?
问题描述
我创建了一个包含 7 个单元的 LSTM-RNN。它减少了损失,但准确度保持为零。在我看到 keras 训练控制台输出之前,我一直无法找出原因。以下是最新训练运行的示例。
Epoch 500/500
2/2 [==============================] - 0s 13ms/step - loss: 0.1505 - accuracy: 0.0000e+00
2/2 是否意味着只在两个样本上进行训练?我有 7168 个数据点,我的批量大小明确表示为 7168,那么为什么会发生这种情况?下面是我的代码
import pandas
import scipy.io as loader
import tensorflow as tf
import keras
import numpy
import time
import math
from tensorflow.keras.datasets import imdb
from tensorflow.keras.layers import Embedding, Dense, LSTM
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.sequence import pad_sequences
additional_metrics = ['accuracy']
loss_function = BinaryCrossentropy()
number_of_epochs = 500
optimizer = SGD()
validation_split = 0.20
verbosity_mode = 1
mini = 0
maxi = 0
mean = 0
"""
"""
def myfunc(arg):
global mini, maxi, mean
return (arg - mean) / (maxi - mini)
# k = 0
cgm = numpy.load('cgm_train_new.npy')
labels = numpy.load('labels_train_new.npy')
labs = list()
cgm_flat = cgm.flatten()
mini = min(cgm_flat)
maxi = max(cgm_flat)
mean = sum(cgm_flat) / len(cgm_flat)
cgm = numpy.apply_along_axis(myfunc, 0, cgm)
for each in labels:
# suma = suma + sum(each)
if each[-1] == 1: labs.append(.99)
else: labs.append(.01)
RNNmodel = Sequential()
RNNmodel.add(LSTM(7, activation='tanh'))
RNNmodel.add(Dense(1, activation='sigmoid'))
RNNmodel.compile(optimizer=optimizer, loss=loss_function, metrics=additional_metrics)
cgm_rs = numpy.reshape(cgm, [len(cgm), 7, 1])
ans = numpy.reshape(labs, [len(labs), 1, 1])
history = RNNmodel.fit(
cgm_rs,
ans,
batch_size=7168,
epochs=number_of_epochs)#,
# verbose=verbosity_mode)#,
# validation_split=validation_split)
tf.keras.utils.plot_model(
RNNmodel,
to_file="RNNmodel.png")
answers = RNNmodel.predict(cgm_rs)
# for each in answers:
# print(each)
解决方案
我已经理解了我的错误。不需要任何答案。谢谢你。
推荐阅读
- sql - 在 Oracle 中将值舍入到小数点到 2 以进行查询
- postgresql - 如何使用与“postgres”不同的用户名登录 PostgreSQL 以及如何加密特定列?
- sql - Sql Server 使用 CONTAINS 函数作为选择语句中的列
- sql - 在查询字符串中使用 '
- pyspark - 如何将一个 6 位数字拆分为一列 4 位和一列 2 位(例如:201452 分为 2014 和 52)
- css - 如何在 .Net Framework WebAPI 项目中使用 npm/yarn 安装?
- java - 存储大量资金总额和内存/存储影响 - BigDecimal vs Integer 和最佳实践?
- mysql - 根据数据库表的其他行中的值设置一列的值
- javascript - 数组和对象中 eslint 的自定义缩进规则
- ios - 无法从 Xcode 将应用加载到应用商店。有什么办法可以解决这个问题吗?