How do I pass Bidirectional LSTM states to an earlier LSTM layer?

Problem description

I'm trying to build a seq2seq model with an encoder LSTM and a decoder LSTM, both with Bidirectional layers.

I can pass the hidden state and memory cell forward to the decoder LSTM, but I can't see how I would pass values from the decoder back to the encoder.

import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM, Dense

def sequence_model(total_words, emb_dimension, lstm_units):
    # Encoder
    encoder_input = Input(shape=(None,), name="Enc_Input")
    x = Embedding(total_words, emb_dimension, input_length=max_sequence_length, name="Enc_Embedding")(encoder_input)  # max_sequence_length is defined elsewhere
    x, state_h, state_c, _, _ = Bidirectional(LSTM(lstm_units, return_state=True, name="Enc_LSTM1"), name="Enc_Bi1")(x) # pass hidden activation and memory cell states forward
    encoder_states = [state_h, state_c] # package states to pass to decoder
    
    # Decoder
    decoder_input = Input(shape=(None,), name="Dec_Input")
    x = Embedding(total_words, emb_dimension, name="Dec_Embedding")(decoder_input)
    x = LSTM(lstm_units, return_sequences=True, name="Dec_LSTM1")(x, initial_state=encoder_states)
    decoder_output = Dense(total_words, activation="softmax", name="Dec_Softmax")(x)

    func_model = tf.keras.Model(inputs=[encoder_input,decoder_input], outputs=decoder_output)
    return func_model

The forward states get passed via initial_state to the decoder LSTM layer Dec_LSTM1. But if I wrap that layer in a Bidirectional layer, it doesn't like me passing the initial_state values and breaks.
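A minimal reproduction of the failure (a sketch with made-up dimensions; the exact exception text depends on the TF/Keras version):

import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM

enc_in = Input(shape=(None,))
enc_x = Embedding(1024, 128)(enc_in)
_, state_h, state_c, _, _ = Bidirectional(LSTM(256, return_state=True))(enc_x)

dec_in = Input(shape=(None,))
dec_x = Embedding(1024, 128)(dec_in)
# Passing only the two forward states to a Bidirectional decoder raises a
# ValueError: the wrapper expects one (h, c) pair per direction, i.e. four tensors.
dec_x = Bidirectional(LSTM(256, return_sequences=True))(
    dec_x, initial_state=[state_h, state_c])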

Am I right in thinking that I don't need the backward states from the encoder LSTM layer?

Attached is an image of the architecture I'm trying to implement.

[image: the intended seq2seq architecture]

Tags: python, tensorflow, keras, lstm

Solution

Your code breaks when you add Bidirectional to the decoder because you are dropping two elements of the encoder state:

x, state_h, state_c, _, _ = ...
#                    ^  ^
# -------------------|--|

An LSTM's state consists of two tensors of shape (batch, hidden); when you run the LSTM in both directions, this adds two more states (from the backward pass).
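Applied to the encoder in the question, that means unpacking and keeping all four state tensors (a sketch; the backward-state variable names are mine):

x, forward_h, forward_c, backward_h, backward_c = Bidirectional(
    LSTM(lstm_units, return_state=True, name="Enc_LSTM1"), name="Enc_Bi1")(x)
# return order is [output, forward_h, forward_c, backward_h, backward_c]
encoder_states = [forward_h, forward_c, backward_h, backward_c]

The demonstration below compares the state lists returned by a plain LSTM and a Bidirectional LSTM, then feeds all four states to a Bidirectional decoder: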

import tensorflow as tf

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Bidirectional
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Dense


enc_in = Input(shape=(None,))
enc_x = Embedding(1024, 128, input_length=92)(enc_in)

# vanilla LSTM
s_enc_x, *s_enc_state = LSTM(256, return_state=True)(enc_x)

print(len(s_enc_state))
print(s_enc_state)
# 2
# [<KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'lstm_7')>,
#  <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'lstm_7')>]

# bi-directional LSTM
bi_enc_x, *bi_enc_state = Bidirectional(LSTM(256, return_state=True))(enc_x)

print(len(bi_enc_state))
print(bi_enc_state)
# 4
# [<KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
#  <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
#  <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
#  <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>]

# decoder
dec_in = Input(shape=(None,))
dec_x = Embedding(1024, 128, input_length=92)(dec_in)
dec_x = Bidirectional(LSTM(256, return_sequences=True))(
    dec_x, initial_state=bi_enc_state)  # <= use bidirectional state
output = Dense(1024, activation="softmax")(dec_x)

print(output.shape)
# TensorShape([None, None, 1024])
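As for whether you need the backward states at all: if you would rather keep the decoder unidirectional, one common alternative (a sketch, not part of the original answer) is to concatenate the forward and backward states and size the decoder LSTM to match:

from tensorflow.keras.layers import Concatenate

# bi_enc_state is ordered [forward_h, forward_c, backward_h, backward_c]
state_h = Concatenate()([bi_enc_state[0], bi_enc_state[2]])  # (None, 512)
state_c = Concatenate()([bi_enc_state[1], bi_enc_state[3]])  # (None, 512)

# the unidirectional decoder needs 2 * 256 units to accept the merged states
uni_dec_x = Embedding(1024, 128)(dec_in)
uni_dec_x = LSTM(512, return_sequences=True)(
    uni_dec_x, initial_state=[state_h, state_c])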
