python - 将编码器从 AutoEncoder 连接到 LSTM
问题描述
我有一个像这样定义的自动编码器
inputs = Input(batch_shape=(1,timesteps, input_dim))
encoded = LSTM(4,return_sequences = True)(inputs)
encoded = LSTM(3,return_sequences = True)(encoded)
encoded = LSTM(2)(encoded)
decoded = RepeatVector(timesteps)(encoded)
decoded = LSTM(3,return_sequences = True)(decoded)
decoded = LSTM(4,return_sequences = True)(decoded)
decoded = LSTM(input_dim,return_sequences = True)(decoded)
sequence_autoencoder = Model(inputs, decoded)
encoder = Model(inputs,encoded)
我希望编码器像这样连接到 LSTM 层
f_input = Input(batch_shape=(1, timesteps, input_dim))
encoder_input = encoder(inputs=f_input)
single_lstm_layer = LSTM(50, kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05))(encoder_input)
drop_1 = Dropout(0.33)(single_lstm_layer)
output_layer = Dense(12, name="Output_Layer"
)(drop_1)
final_model = Model(inputs=[f_input], outputs=[output_layer])
但它给了我一个尺寸错误。
Input 0 is incompatible with layer lstm_3: expected ndim=3, found ndim=2
我怎样才能正确地做到这一点。?
解决方案
我认为主要问题源于最后一个encoded
不是重复向量这一事实。要将编码器输出馈送到 LSTM,它需要通过一个RepeatVector
层发送。换句话说,编码器的最后一个输出需要具有[batch_size, time_steps, dim]
能够输入 LSTM 的形状。这可能是你要找的?
inputs = Input(batch_shape=(1,timesteps, input_dim))
encoded = LSTM(4,return_sequences = True)(inputs)
encoded = LSTM(3,return_sequences = True)(encoded)
encoded = LSTM(2)(encoded)
encoded_repeat = RepeatVector(timesteps)(encoded)
decoded = LSTM(3,return_sequences = True)(encoded_repeat)
decoded = LSTM(4,return_sequences = True)(decoded)
decoded = LSTM(input_dim,return_sequences = True)(decoded)
sequence_autoencoder = Model(inputs, decoded)
encoder = Model(inputs,encoded_repeat)
f_input = Input(batch_shape=(1, timesteps, input_dim))
encoder_input = encoder(inputs=f_input)
single_lstm_layer = LSTM(50, kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05))(encoder_input)
drop_1 = Dropout(0.33)(single_lstm_layer)
output_layer = Dense(12, name="Output_Layer"
)(drop_1)
final_model = Model(inputs=[f_input], outputs=[output_layer])
我已将您的第一个重命名decoded
为encode_repeat
推荐阅读
- r - 一次将数据框中的所有 0 值替换为多行和多列的 1
- javascript - MDN webRTC still photo capture demo stops working recently
- r - Extract value from tibble in R markdown inline code
- java - javax.crypto.IllegalBlockSizeException:密码函数:OPENSSL_internal:WRONG_FINAL_BLOCK_LENGTH
- ios - Crashed: com.twitter.crashlytics.ios.exception IOS
- java - 使用 Long 类型的对象参数对对象列表进行排序
- python - How to pull year and month from every month between date range into list?
- python - Converting TeX file to image in Python
- angularjs - AngularJS, How to call a function in component from Directive?
- python - Timestamping each frame of a video