python - 两个序列到序列模型keras之间的区别(有和没有RepeatVector)
问题描述
我试图了解这个模型在这里描述的区别是什么,以下是:
from keras.layers import Input, LSTM, RepeatVector
from keras.models import Model
inputs = Input(shape=(timesteps, input_dim))
encoded = LSTM(latent_dim)(inputs)
decoded = RepeatVector(timesteps)(encoded)
decoded = LSTM(input_dim, return_sequences=True)(decoded)
sequence_autoencoder = Model(inputs, decoded)
encoder = Model(inputs, encoded)
这里描述的序列到序列模型是 第二个描述
有什么不同 ?第一个有RepeatVector,而第二个没有?第一个模型是否没有将解码器隐藏状态作为预测的初始状态?
有没有描述第一个和第二个的论文?
解决方案
在使用 的模型中RepeatVector
,他们没有使用任何花哨的预测,也没有处理状态。他们让模型在内部完成所有操作,并且RepeatVector
用于将(batch, latent_dim)
向量(不是序列)转换为 a (batch, timesteps, latent_dim)
(现在是正确的序列)。
现在,在另一个没有 的模型中,RepeatVector
秘密在于这个附加功能:
def decode_sequence(input_seq):
# Encode the input as state vectors.
states_value = encoder_model.predict(input_seq)
# Generate empty target sequence of length 1.
target_seq = np.zeros((1, 1, num_decoder_tokens))
# Populate the first character of target sequence with the start character.
target_seq[0, 0, target_token_index['\t']] = 1.
# Sampling loop for a batch of sequences
# (to simplify, here we assume a batch of size 1).
stop_condition = False
decoded_sentence = ''
while not stop_condition:
output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
# Sample a token
sampled_token_index = np.argmax(output_tokens[0, -1, :])
sampled_char = reverse_target_char_index[sampled_token_index]
decoded_sentence += sampled_char
# Exit condition: either hit max length
# or find stop character.
if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
stop_condition = True
# Update the target sequence (of length 1).
target_seq = np.zeros((1, 1, num_decoder_tokens))
target_seq[0, 0, sampled_token_index] = 1.
# Update states
states_value = [h, c]
return decoded_sentence
这会运行一个基于 a 的“循环”,stop_condition
用于一个一个地创建时间步长。(这样做的好处是使句子没有固定长度)。
它还显式地获取每个步骤中生成的状态(以保持每个单独步骤之间的正确连接)。
简而言之:
- 模型 1:通过重复潜在向量来创建长度
- 模型 2:通过循环新步骤直到达到停止条件来创建长度
推荐阅读
- python - Python请求或beautifulsoup4延迟
- python - 带有嵌套for循环的Python多维数组
- php - laravel eloquent 用多表序列化自定义关系
- python-import - 移动了类位置,现在无法腌制该类的对象
- c# - 超过双倍最大值
- python - 如何安装python2.7版本的pip?
- python - 如何使用 boto3 轮换我的 AWS IAM 用户访问权限和密钥?
- flutter - Flutter中根据组名导航到不同的页面
- opengl - OpenGL Nvidia驱动程序错误?
- qemu - 如何在 qemu 上模拟支持安全功能的 sata 设备?