python - 初始状态 lstm 编码器 解码器 keras
问题描述
我正在尝试构建一个 LSTM 编码器解码器,我的主要目标是解码器的初始状态与编码器相同。我从这里找到了下面的代码,并试图将其附加到我的案例中。我有一个形状为 (1000,20,1) 的数据。我希望编码器解码器在输出中将我的输入返回给我。即使我理解错误,我也不知道如何更正它正在工作的代码。当我尝试运行它时,我收到以下错误:
The model expects 2 input arrays, but only received one array. Found:
array with shape (10000, 20, 1)
from keras.models import Model
from keras.layers import Input
from keras.layers import LSTM
from keras.layers import Dense
from keras.models import Sequential
latent_dim = 128
encoder_inputs = Input(shape=(20,1))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(20, 1))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
decoder_dense = Dense(1, activation='tanh')
decoder_outputs = decoder_dense(decoder_outputs)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='adam', loss='mse', metrics=['acc', 'mae'])
history=model.fit(xtrain, xtrain, epochs=200, verbose=2, shuffle=False)
我也有这个模型,但我不确定如何在这里初始化与解码器状态相同的编码器状态。重复向量是这样做的吗?
#define model
model = Sequential()
model.add(LSTM(100, input_shape=(n_timesteps_in, n_features)))
model.add(RepeatVector(n_timesteps_in))
model.add(LSTM(100, return_sequences=True))
model.add(TimeDistributed(Dense(n_features, activation='tanh')))
model.compile(loss='mse', optimizer='adam', metrics=['mae'])
history=model.fit(train, train, epochs=epochs, verbose=2, shuffle=False)
解决方案
您正在构建一个具有 2 个输入的模型,即encoder_inputs
但decoder_inputs
只给一个输入.fit(xtrain, xtrain, ...)
,第二个参数是输出。如果您需要提供另一个形式的参数.fit([xtrain, the_inputs_for_decoder], xtrain, ...)
推荐阅读
- mysql - 在 INSERT 上搜索和更新
- docker - Docker for Mac `docker images` 命令返回“致命错误:故障”
- r - 从并行循环内部将文本附加到现有文件
- c# - 文本到语音 C#
- django - 警告:Django_mysql.w001
- phpunit - SilverStripe 4:FunctionalTest “get” 方法返回 404 状态,尽管页面在那里。
- javascript - 将我的脚本移动到外部页面,它们不再响应
- javascript - Textarea.value 不操纵 HTML
- ansible - 使用 ansible 将文件从远程主机复制到 vagrant 实例
- hadoop - hadoop reducer 的实现是否依赖于一个大的哈希映射来删除所有相同的键?