首页 > 解决方案 > 没有 tf.contrib.rnn.MultiRNNCell 的多层 LSTM

问题描述

为了实现多层 LSTM 网络,我通常使用以下代码:

def lstm_cell():
    return tf.contrib.rnn.LayerNormBasicLSTMCell(model_settings['rnn_size'])
    
attn_cell = lstm_cell
    
def attn_cell():
    return tf.contrib.rnn.DropoutWrapper(lstm_cell(), output_keep_prob=0.7)
    
cell = tf.contrib.rnn.MultiRNNCell([attn_cell() for _ in range(num_layers)], state_is_tuple=True)
outputs_, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

但是,这样一来,如果我想操纵隐藏层输出的排列,我就无法访问隐藏层输出。有没有其他方法可以在不使用tf.contrib.rnn.MultiRNNCell的情况下制作多层 LSTM 网络?

标签: tensorflowneural-networklstm

解决方案


您可以简单地堆叠几个 LSTM 层,例如通过 Sequential 模块:

model = Sequential()
model.add(layers.LSTM(..., return_sequences = True, input_shape = (...)))
model.add(layers.LSTM(..., return_sequences = True)
...
model.add(layers.LSTM(...))

在这种情况下,return sequences关键字对于中间层至关重要。


推荐阅读