首页 > 解决方案 > 如何在 tensforflow 2.0 中替换 OutputProjectionWrapper

问题描述

我有以下带有注意机制的 seq2seq 解码器的代码片段。它适用于 tensorflow 1.13。现在我需要使用 keras 升级到 tensorflow 2.0,但是 tf.contrib.rnn.OutputProjectionWrapper 已经在 tensorflow 2.0 中删除了。如何实施?

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                num_units, memory=memory,
         memory_sequence_length=self.encoder_inputs_actual_length)
cell = tf.contrib.rnn.LSTMCell(num_units)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
                cell, attention_mechanism, attention_layer_size)
out_cell = tf.contrib.rnn.OutputProjectionWrapper(
                attn_cell, self.output_size, reuse=reuse)
decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=out_cell, helper=helper,
                initial_state=out_cell.zero_state(
                    dtype=tf.float32, batch_size=self.batch_size))
final_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder, output_time_major=True,
                impute_finished=True, 
maximum_iterations=self.input_steps
            )

我阅读了https://www.oreilly.com/library/view/neural-networks-and/9781492037354/ch04.html但没有弄清楚如何将完整的连接添加到我的案例中。

我尝试使用带有急切模式的最新 seq2seq 插件,如下所示,没有语法错误,但我不确定它是否正确。之前的 tf 1.13 版本预测准确率快速达到 90%,但新的 tf2.0 版本准确率始终在 60% 左右。

attention_mechanism = tfa.seq2seq.BahdanauAttention(num_units,memory,memory_sequence_length)
lstm_cell = layers.LSTMCell(num_units)
attn_cell = tfa.seq2seq.AttentionWrapper(lstm_cell,attention_mechanism, attention_layer_size=num_units) 
output_layer = layers.Dense(self.output_size)
basic_decoder = tfa.seq2seq.BasicDecoder(cell=attn_cell, sampler=sampler,output_layer=output_layer,output_time_major=True,impute_finished=True,maximum_iterations=self.input_steps)
initial_state = attn_cell.get_initial_state(batch_size=self.batch_size,dtype=tf.float32).clone(cell_state=encoder_final_state)
final_outputs, _, _ = basic_decoder(encoder_outputs_sequence,initial_state=initial_state)

谢谢。

标签: tensorflowkerastensorflow2.0

解决方案


我终于弄清楚了准确率保持在 60% 左右的原因是 AttentionWrapper 默认会输出注意力分数,但在我的情况下,我需要实际输出来计算下一个注意力分数。解决方法是在 AttentionWrapper 中设置 output_attention=False:

attn_cell = tfa.seq2seq.AttentionWrapper(lstm_cell,attention_mechanism, 
attention_layer_size=num_units, output_attention=False) 

在这里更新它以防有人遇到同样的问题。


推荐阅读