python - 如何在网络中使用 RNN 单元?
问题描述
我正在尝试在我的网络中使用自定义的 RNN 单元。我从 Keras 的 RNN 单元示例开始,其中 RNN 单元定义为 MinimalRNNCell。当我试图在我的循环网络中使用定义的单元时,通过用自定义的 RNN 单元替换我之前使用的 simpleRNN,但出现了这个错误:
ValueError:一个操作有None
梯度。请确保您的所有操作都定义了渐变(即可微分)。没有梯度的常见操作:K.argmax、K.round、K.eval。
class MinimalRNNCell(Layer):
def __init__(self, units, **kwargs):
self.units = units
self.state_size = units
super(MinimalRNNCell, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
prev_output = states[0]
h = K.dot(inputs, self.kernel)
output = h + K.dot(prev_output, self.recurrent_kernel)
return output, [output]
# Let's use this cell in a RNN layer:
cell = MinimalRNNCell(32)
x = keras.Input((None, n_f))
layer = RNN(cell)
y = layer(x)
# prepare sequence
length = 10
n_f = 10
# define LSTM configuration
n_neurons = length
n_batch = length
n_epoch = 1000
# create LSTM
model = Sequential()
#model.add(SimpleRNN(n_neurons, input_shape=(length, n_f) ,return_sequences=True))
model.add(RNN(cell, input_shape=(length, n_f) ,return_sequences=True))
model.add(TimeDistributed(Dense(n_neurons)))
model.add(Activation('relu'))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.summary())
# train LSTM
ES = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience= int(n_epoch/2))
history = model.fit(X_train, y_train, validation_data= (X_Val,y_Val),epochs=n_epoch, batch_size=n_batch, verbose=2 , callbacks=[ES])
解决方案
推荐阅读
- javascript - for循环语句中的var关键字解释
- jenkins - kubectl 命令在 VM 中运行但不在 jenkins 管道中运行
- reactjs - React - 如何处理材料ui的datagrid表的特定列的单元格单击?
- uipath - 如何清除 uipath 中的 chrome 浏览器历史记录?
- r - R plotly subplot 共享 X 轴和共享 Y 轴
- ionic-framework - 离子 NFC 依赖项
- elixir - 如何查询月份和工作日?
- oracle - 无法从 obiee 12c 连接 essbase
- python - 为 Django 多对多关系定义连接表是最佳实践吗?
- php - 如何从json读取数组