tensorflow - Seq2Seq 模型为所有句子返回相同的向量
问题描述
我正在尝试生成抽象的文本摘要。我使用 word2vec 进行嵌入,编码器中使用 2 层的 bi-lstm,解码器中使用 1 层的 bi-lstm,而且我正在使用 Attention。我训练了模型,它总是为所有输入句子返回相同的向量。我该如何解决这个问题?
培训代码
latent_dim = 185
embedding_dim=128
encoder_inputs = Input(shape=(int(art_max_length),))
#embedding layer
enc_emb=Embedding(input_vocab_size+1,embedding_dim, weights=[x_emb_matrix_reduce],trainable=False)(encoder_inputs)
#encoder lstm 1
encoder_bi_lstm1 = Bidirectional(LSTM(latent_dim,
return_sequences=True,
return_state=True,
dropout=0.4,
recurrent_dropout=0.4),
merge_mode="concat")
encoder_output1, forward_state_h1, forward_state_c1, backward_state_h1, backward_state_c1 = encoder_bi_lstm1(enc_emb)
encoder_states1 = [forward_state_h1, forward_state_c1, backward_state_h1, backward_state_c1]
#encoder lstm 2
encoder_bi_lstm2 = Bidirectional(LSTM(latent_dim,
return_sequences=True,
return_state=True,
dropout=0.4,
recurrent_dropout=0.4),
merge_mode="concat")
encoder_output2, forward_state_h2, forward_state_c2, backward_state_h2, backward_state_c2 = encoder_bi_lstm2(encoder_output1)
encoder_states2 = [forward_state_h2, forward_state_c2, backward_state_h2, backward_state_c2]
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
#embedding layer
dec_emb_layer = Embedding(output_vocab_size+1, embedding_dim, weights=[y_emb_matrix_reduce], trainable=False)
dec_emb = dec_emb_layer(decoder_inputs)
decoder_bi_lstm = Bidirectional(LSTM(latent_dim,
return_sequences=True,
return_state=True,
dropout=0.4,
recurrent_dropout=0.2),
merge_mode="concat")
decoder_outputs, decoder_fwd_state_h1, decoder_fwd_state_c1, decoder_back_state_h1, decoder_back_state_c1 = decoder_bi_lstm(dec_emb,initial_state=encoder_states2)
decoder_states = [decoder_fwd_state_h1, decoder_fwd_state_c1, decoder_back_state_h1, decoder_back_state_c1]
# Attention layer
attn_layer = AttentionLayer(name='attention_layer')
attn_out, attn_states = attn_layer([encoder_output2, decoder_outputs])
# Concat attention input and decoder LSTM output
decoder_concat_input = Concatenate(axis=-1, name='concat_layer')([decoder_outputs, attn_out])
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
epochs = 75
batch_size = 3
learning_rate = 0.001
initial_accumulator_value = 0.1
name = 'Adagrad'
clipnorm = 1.0
opt = Adagrad(learning_rate=learning_rate, initial_accumulator_value=initial_accumulator_value, name=name, clipnorm=clipnorm)
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy')
es = EarlyStopping(monitor='val_loss', mode='auto', verbose=1,patience=10)
history=model.fit(x_tr, y_tr, epochs=epochs, callbacks=[es], steps_per_epoch=250, validation_steps=10, batch_size=batch_size, validation_data=(x_val,y_val))
推理代码
reverse_target_word_index = y_tokenizer.index_word
reverse_source_word_index = x_tokenizer.index_word
target_word_index = y_tokenizer.word_index
# Encode the input sequence to get the feature vector
encoder_model = Model(inputs=encoder_inputs, outputs=[encoder_output2, forward_state_h2, forward_state_c2, backward_state_h2, backward_state_c2])
# Decoder setup
# Below tensors will hold the states of the previous time step
decoder_state_input_h_fwd = Input(shape=(latent_dim,))
decoder_state_input_h_bwd = Input(shape=(latent_dim,))
decoder_state_input_c_fwd = Input(shape=(latent_dim,))
decoder_state_input_c_bwd = Input(shape=(latent_dim,))
decoder_hidden_state_input = Input(shape=(art_max_length,latent_dim*2))
# Get the embeddings of the decoder sequence
dec_emb2= dec_emb_layer(decoder_inputs)
# To predict the next word in the sequence, set the initial states to the states from the previous time step
decoder_outputs2, decoder_fwd_state_h2, decoder_fwd_state_c2, decoder_back_state_h2, decoder_back_state_c2 = decoder_bi_lstm(dec_emb2, initial_state=[decoder_state_input_h_fwd, decoder_state_input_h_bwd, decoder_state_input_c_fwd, decoder_state_input_c_bwd])
decoder_states2 = [decoder_fwd_state_h2, decoder_fwd_state_c2, decoder_back_state_h2, decoder_back_state_c2]
#attention inference
attn_out_inf, attn_states_inf = attn_layer([decoder_hidden_state_input, decoder_outputs2])
decoder_inf_concat = Concatenate(axis=-1, name='concat')([decoder_outputs2, attn_out_inf])
# A dense softmax layer to generate prob dist. over the target vocabulary
decoder_outputs2 = decoder_dense(decoder_inf_concat)
# Final decoder model
decoder_model = Model(
[decoder_inputs] + [decoder_hidden_state_input, decoder_state_input_h_fwd, decoder_state_input_h_bwd, decoder_state_input_c_fwd, decoder_state_input_c_bwd],
[decoder_outputs2] + decoder_states2)
生成摘要的代码
def seq2summary(input_seq):
newString=''
for i in input_seq:
if((i[0]!=0) and (i[0]!=target_word_index['sostok']) and (i[0]!=target_word_index['eostok'])):
newString=newString+reverse_target_word_index[i[0]]+' '
return newString
def seq2text(input_seq):
newString=''
for i in input_seq:
if(i!=0):
newString=newString+reverse_source_word_index[i]+' '
return newString
def decode_sequence(input_seq):
e_out, e_h_fwd, e_c_fwd, e_h_bwd, e_c_bwd = encoder_model.predict(input_seq)
# Generate empty target sequence of length 1.
target_seq = np.zeros((1,1))
# Populate the first word of target sequence with the start word.
target_seq[0, 0] = target_word_index['sostok']
stop_condition = False
decoded_sentence = ''
while not stop_condition:
output_tokens, h_fwd, c_fwd, h_bwd, c_bwd = decoder_model.predict([target_seq] + [e_out, e_h_fwd, e_c_fwd, e_h_bwd, e_c_bwd])
return output_tokens[0, -1, :]
# Sample a token
sampled_token_index = np.argmax(output_tokens[0, -1, :])
sampled_token = reverse_target_word_index[sampled_token_index]
if(sampled_token!='eostok'):
decoded_sentence += ' '+sampled_token
# Exit condition: either hit max length or find stop word.
if (sampled_token == 'eostok' or len(decoded_sentence.split()) >= (high_max_length-1)):
stop_condition = True
# Update the target sequence (of length 1).
target_seq = np.zeros((1,1))
target_seq[0, 0] = sampled_token_index
# Update internal states
e_h_fwd, e_c_fwd, e_h_bwd, e_c_bwd = h_fwd, c_fwd, h_bwd, c_bwd
return decoded_sentence
向量的结果:
e_h_fwd
Result:
array([[-0.00384058, -0.0084216 , 0.00099799, -0.00328317, -0.00355412,
0.01573788, -0.00565114, 0.0002754 , -0.01011071, 0.03385576,
0.01035002, 0.0010401 , 0.01606524, 0.00338535, -0.0208919 ,
0.002799 , -0.00558226, -0.00252697, -0.00916545, 0.00482792,
0.00838646, -0.00736981, -0.00089604, -0.00780456, 0.00439578,
0.02386101, -0.01245494, 0.0068648 , -0.01109423, -0.00279979,
-0.0048555 , 0.00291485, -0.00111228, 0.0121593 , 0.00718876,
-0.00367533, 0.00612858, 0.0026198 , -0.00990033, -0.00372838,
0.01660432, 0.01064453, 0.01216934, -0.01671972, -0.021307 ,
0.00358878, -0.00851676, 0.00872963, -0.00098289, -0.00512723,
0.00447382, -0.00086343, 0.00142587, -0.01713295, -0.01154616,
-0.00318079, -0.0213894 , 0.01909565, 0.00537347, 0.00287433,
0.00013318, 0.01882311, -0.00919805, -0.01009239, -0.01000161,
0.00729822, -0.00228036, 0.01970326, -0.00668583, 0.01141307,
-0.00155173, -0.00519767, -0.005886 , 0.00621226, 0.0005807 ,
-0.00401507, -0.02050336, -0.0063515 , -0.0088415 , 0.01226105,
0.00378229, 0.00897009, -0.00173353, -0.00694196, 0.00197844,
-0.0178321 , 0.00554329, 0.01416476, -0.01519079, 0.00422954,
-0.00771015, 0.00344123, -0.01047825, -0.00756182, -0.00108388,
-0.01648704, 0.00209498, 0.0071196 , -0.01291664, -0.00549853,
-0.01216177, 0.0046125 , 0.00120374, -0.00372009, 0.01676877,
-0.00930131, -0.00677394, -0.0162948 , -0.00530502, -0.01685343,
0.01167075, 0.0062821 , -0.01340364, 0.00760005, -0.0337769 ,
0.00708523, -0.00263025, 0.00446939, 0.02564106, -0.00254333,
-0.00707568, 0.01608927, 0.00716687, -0.00965973, -0.00327503,
-0.00604013, 0.0175317 , 0.01505202, -0.00426429, -0.00377769,
-0.00929095, -0.01969613, 0.00719869, -0.01020684, 0.01040385,
0.01139158, -0.0043503 , -0.00274339, -0.00616975, -0.01331878,
0.00295496, -0.01160615, -0.00336138, 0.00886331, -0.02004485,
0.01137386, 0.00428817, -0.00449507, -0.00655314, -0.01015342,
0.02188095, 0.00309571, 0.00742747, 0.02219234, 0.00236926,
-0.00491316, 0.01939732, 0.01722919, 0.00388572, 0.02340838,
-0.01717703, -0.00525931, 0.01344595, -0.00262558, 0.01469047,
0.0196475 , 0.01402889, -0.0011783 , -0.01755165, -0.01247887,
0.01138979, 0.00034305, 0.00225358, -0.01848649, -0.01921862,
0.0028248 , -0.01087625, 0.00121242, -0.02166731, 0.01230442,
0.01093107, 0.01236717, -0.01110782, -0.00536899, -0.01232667]],
dtype=float32)
output_tokens
Result:
array([[[6.1362894e-06, 5.7854427e-06, 5.8488249e-06, ...,
5.7374464e-06, 5.7320071e-06, 5.7324951e-06]]], dtype=float32)
np.argmax(output_tokens[0, -1, :])
Result:
0
解决方案
推荐阅读
- reactjs - 按索引反应选择 OnChange
- php - .htaccess 带有空格的重定向网址
- node.js - 在 Windows 7 上安装 Node JS 卡在命令提示符上
- android - Android 点击事件是如何传递的?
- python - 具有一个热编码特征的 Auto-Sklearn 中的特征和特征重要性
- sql - (Oracle)在主表上添加过滤器是否可以提高主从之间的左连接条件的性能?
- javascript - 我可以从整个视频中接收一些数据并选择性地控制其余部分吗?
- java - Java泛型从参数返回类型的子集合
- unix - 在 AIX 中验证密码
- python - Flask Docker 映像在 Azure 上出现超时错误