tensorflow - 如何在 tensforflow 2.0 中替换 OutputProjectionWrapper
问题描述
我有以下带有注意机制的 seq2seq 解码器的代码片段。它适用于 tensorflow 1.13。现在我需要使用 keras 升级到 tensorflow 2.0,但是 tf.contrib.rnn.OutputProjectionWrapper 已经在 tensorflow 2.0 中删除了。如何实施?
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
num_units, memory=memory,
memory_sequence_length=self.encoder_inputs_actual_length)
cell = tf.contrib.rnn.LSTMCell(num_units)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
cell, attention_mechanism, attention_layer_size)
out_cell = tf.contrib.rnn.OutputProjectionWrapper(
attn_cell, self.output_size, reuse=reuse)
decoder = tf.contrib.seq2seq.BasicDecoder(
cell=out_cell, helper=helper,
initial_state=out_cell.zero_state(
dtype=tf.float32, batch_size=self.batch_size))
final_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
decoder=decoder, output_time_major=True,
impute_finished=True,
maximum_iterations=self.input_steps
)
我阅读了https://www.oreilly.com/library/view/neural-networks-and/9781492037354/ch04.html但没有弄清楚如何将完整的连接添加到我的案例中。
我尝试使用带有急切模式的最新 seq2seq 插件,如下所示,没有语法错误,但我不确定它是否正确。之前的 tf 1.13 版本预测准确率快速达到 90%,但新的 tf2.0 版本准确率始终在 60% 左右。
attention_mechanism = tfa.seq2seq.BahdanauAttention(num_units,memory,memory_sequence_length)
lstm_cell = layers.LSTMCell(num_units)
attn_cell = tfa.seq2seq.AttentionWrapper(lstm_cell,attention_mechanism, attention_layer_size=num_units)
output_layer = layers.Dense(self.output_size)
basic_decoder = tfa.seq2seq.BasicDecoder(cell=attn_cell, sampler=sampler,output_layer=output_layer,output_time_major=True,impute_finished=True,maximum_iterations=self.input_steps)
initial_state = attn_cell.get_initial_state(batch_size=self.batch_size,dtype=tf.float32).clone(cell_state=encoder_final_state)
final_outputs, _, _ = basic_decoder(encoder_outputs_sequence,initial_state=initial_state)
谢谢。
解决方案
我终于弄清楚了准确率保持在 60% 左右的原因是 AttentionWrapper 默认会输出注意力分数,但在我的情况下,我需要实际输出来计算下一个注意力分数。解决方法是在 AttentionWrapper 中设置 output_attention=False:
attn_cell = tfa.seq2seq.AttentionWrapper(lstm_cell,attention_mechanism,
attention_layer_size=num_units, output_attention=False)
在这里更新它以防有人遇到同样的问题。
推荐阅读
- ios - 拍摄照片时未使用 AVCaptureDevice 的设置
- python-3.x - 我在 selenium 'NoneType' 对象没有属性 'options' 上收到此错误
- ios - Unity - 多显示器(iPad + 显示器)
- twitter-bootstrap - 如何消除引导程序中并排元素之间的间隙?
- reactjs - Ag Grid Enterprise 功能
- python - 关于 pandas groupby 应用列作为参数
- php - 将 nodeValue 与 Switch Case 匹配
- java - JVM 如何为对象分配内存,尽管它们以后可能会变大
- json - 如何将类型 json 文件加载到 DropdownButton
- ios - iOS 使用 OC 通讯录右侧滑动检索波效果