python - 使用 Keras 在 RNN 中生成消失和爆炸梯度问题
问题描述
我了解 Vanilla RNN 中的梯度消失和爆炸问题以及为什么会发生这种情况。但是,我想有目的地创建这个问题以便更好地理解。我从https://www.datatechnotes.com/2018/12/rnn-example-with-keras-simplernn-in.html获取了以下代码。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN
# convert into dataset matrix
def convertToMatrix(data, step):
X, Y =[], []
for i in range(len(data)-step):
d=i+step
X.append(data[i:d,])
Y.append(data[d,])
return np.array(X), np.array(Y)
step = 4
N = 1000
Tp = 800
t=np.arange(0,N)
x=np.sin(0.02*t)+2*np.random.rand(N)
df = pd.DataFrame(x)
df.head()
plt.plot(df)
plt.show()
values=df.values
train,test = values[0:Tp,:], values[Tp:N,:]
# add step elements into train and test
test = np.append(test,np.repeat(test[-1,],step))
train = np.append(train,np.repeat(train[-1,],step))
trainX,trainY =convertToMatrix(train,step)
testX,testY =convertToMatrix(test,step)
trainX = np.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))
testX = np.reshape(testX, (testX.shape[0], 1, testX.shape[1]))
model = Sequential()
model.add(SimpleRNN(units=32, input_shape=(1,step), activation="relu"))
model.add(Dense(8, activation="relu"))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='rmsprop')
model.summary()
model.fit(trainX,trainY, epochs=100, batch_size=16, verbose=2)
trainPredict = model.predict(trainX)
testPredict= model.predict(testX)
predicted=np.concatenate((trainPredict,testPredict),axis=0)
trainScore = model.evaluate(trainX, trainY, verbose=0)
print(trainScore)
我应该如何修改此代码以创建此问题?谢谢你。
解决方案
梯度消失是我们使用 sigmoid 激活函数时的问题。如果更改relu
为sigmoid
,可能会遇到梯度消失问题。
model = Sequential()
model.add(SimpleRNN(units=32, input_shape=(1,step), activation="sigmoid"))
model.add(Dense(8, activation="sigmoid"))
推荐阅读
- python - 如何在移动设备中移植 Django Web 应用程序(iOS 和 Android)
- typescript - 在 TypeScript 检查之前修改 prop 值
- apollo - 阿波罗联盟验证另一个子图中的输入字段
- go - 在 go 中复制指针内容会导致不必要的开销?
- java - 提供的 javaHome 似乎无效。我找不到 java 可执行文件。尝试位置:C:\Program Files\Java\jdk-17\bin\java.exe
- python - ThreadPoolExecutor 是否保证将 N 个任务均匀分布在 N 个线程上?
- java - 处理 JPMS 模块和与 java 8 的兼容性的正确方法?
- java - Java 短日期字符串转换为 ZonedDateTime
- visual-studio-code - 即使没有安装,如何在 Visual Studio Code 中禁用 Rickroll?
- javascript - Vue.js:手风琴在 for 循环中不起作用