python - Stacking more than one RNN layer in a decoder with the TensorFlow seq2seq addon
Problem description
I am running into a problem while building a seq2seq model for language translation with TensorFlow's seq2seq addon (tensorflow_addons). I am following this tutorial: https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt, which shows how to create a decoder with a single RNN layer (an LSTM in this case), but I cannot figure out how to stack more than one layer with this addon, or how to initialize the stacked layers.
Below is the code for the decoder class:
import tensorflow as tf
import tensorflow_addons as tfa

class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, attention_type='luong'):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.attention_type = attention_type

        # Embedding layer
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

        # Final dense layer on which softmax will be applied
        self.fc = tf.keras.layers.Dense(vocab_size)

        # Define the fundamental cell for the decoder's recurrent structure
        self.decoder_rnn_cell = tf.keras.layers.LSTMCell(self.dec_units)

        # Sampler
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()

        # Create attention mechanism with memory = None
        # (max_length_input is a global defined earlier in the tutorial)
        self.attention_mechanism = self.build_attention_mechanism(
            self.dec_units, None, self.batch_sz * [max_length_input], self.attention_type)

        # Wrap the attention mechanism around the fundamental RNN cell of the decoder
        self.rnn_cell = self.build_rnn_cell(batch_sz)

        # Define the decoder with respect to the fundamental RNN cell
        self.decoder = tfa.seq2seq.BasicDecoder(
            self.rnn_cell, sampler=self.sampler, output_layer=self.fc)

    def build_rnn_cell(self, batch_sz):
        rnn_cell = tfa.seq2seq.AttentionWrapper(
            self.decoder_rnn_cell, self.attention_mechanism,
            attention_layer_size=self.dec_units)
        return rnn_cell

    def build_attention_mechanism(self, dec_units, memory, memory_sequence_length, attention_type='luong'):
        # attention_type: which kind of attention to use (Bahdanau or Luong)
        # dec_units: final dimension of the attention outputs
        # memory: encoder hidden states of shape (batch_size, max_length_input, enc_units)
        # memory_sequence_length: 1-D array of shape (batch_size,) with every element
        #   set to max_length_input (for masking purposes)
        if attention_type == 'bahdanau':
            return tfa.seq2seq.BahdanauAttention(
                units=dec_units, memory=memory, memory_sequence_length=memory_sequence_length)
        else:
            return tfa.seq2seq.LuongAttention(
                units=dec_units, memory=memory, memory_sequence_length=memory_sequence_length)

    def build_initial_state(self, batch_sz, encoder_state, Dtype):
        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size=batch_sz, dtype=Dtype)
        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
        return decoder_initial_state

    def call(self, inputs, initial_state):
        x = self.embedding(inputs)
        # max_length_output is a global defined earlier in the tutorial
        outputs, _, _ = self.decoder(
            x, initial_state=initial_state,
            sequence_length=self.batch_sz * [max_length_output - 1])
        return outputs
Solution
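The tutorial builds the decoder around a single tf.keras.layers.LSTMCell, but tfa.seq2seq.AttentionWrapper accepts any Keras RNN cell, so one approach is to hand it a tf.keras.layers.StackedRNNCells containing several LSTMCell instances. Below is a minimal sketch of that idea; the class name StackedDecoder, the num_layers and max_length_input parameters, and the choice of replicating the encoder state across every layer are assumptions for illustration, not something taken from the tutorial.

import tensorflow as tf
import tensorflow_addons as tfa

class StackedDecoder(tf.keras.Model):
    # Variant of the tutorial's Decoder with num_layers stacked LSTM cells
    # (hypothetical adaptation; max_length_input is passed in instead of
    # being read from a global).
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz,
                 max_length_input, num_layers=2, attention_type='luong'):
        super(StackedDecoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.fc = tf.keras.layers.Dense(vocab_size)

        # A stack of LSTM cells instead of a single one; AttentionWrapper
        # treats the stack as one composite RNN cell.
        self.decoder_rnn_cell = tf.keras.layers.StackedRNNCells(
            [tf.keras.layers.LSTMCell(dec_units) for _ in range(num_layers)])

        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        self.attention_mechanism = tfa.seq2seq.LuongAttention(
            units=dec_units, memory=None,
            memory_sequence_length=batch_sz * [max_length_input])

        # Wrapping works exactly as in the single-layer case.
        self.rnn_cell = tfa.seq2seq.AttentionWrapper(
            self.decoder_rnn_cell, self.attention_mechanism,
            attention_layer_size=dec_units)

        self.decoder = tfa.seq2seq.BasicDecoder(
            self.rnn_cell, sampler=self.sampler, output_layer=self.fc)

    def build_initial_state(self, batch_sz, encoder_state, Dtype):
        decoder_initial_state = self.rnn_cell.get_initial_state(
            batch_size=batch_sz, dtype=Dtype)
        # The wrapped stack now expects one [h, c] pair per layer, so
        # cell_state must hold per-layer states rather than a single pair.
        # Seeding every layer with the encoder's final [enc_h, enc_c] is one
        # choice (an assumption here) and requires enc_units == dec_units;
        # seeding only the first layer and keeping the zero state for the
        # rest is another common option.
        stacked_state = [encoder_state for _ in range(self.num_layers)]
        return decoder_initial_state.clone(cell_state=stacked_state)

With this construction the tutorial's call method and training loop should work unchanged, and the attention memory is still attached later via attention_mechanism.setup_memory(enc_output) once the encoder output is available, as the tutorial does for the single-layer decoder.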