python - Tensorflow BahdanauAttention - 层 memory_layer 需要 1 个输入,但它接收到 2 个输入张量
问题描述
张量流:1.12
我正在使用bidirectional_dynamic_rnn
. 我写信encoder_output
给 BahdanauAttention 的内存选项(在 tensorflow 网站上推荐),但它抛出了一个错误:
ValueError: Layer memory_layer 需要 1 个输入,但它接收到 2 个输入张量。收到的输入:tf.Tensor 'bidirectional_rnn/fw/fw/transpose_1:0' shape=(?, ?, 512) dtype=float32, tf.Tensor 'ReverseSequence:0' shape=(?, ?, 512) dtype=float32 ]
def model_inputs():
inputs = tf.placeholder(tf.int32, [None, None], name='input')
targets = tf.placeholder(tf.int32, [None, None], name='target')
lr = tf.placeholder(tf.float32, name='learning_rate')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
return inputs, targets, lr, keep_prob
def preprocess_targets(targets, word2int, batch_size):
left_side = tf.fill([batch_size, 1], word2int['<SOS>'])
right_side = tf.strided_slice(targets, [0,0], [batch_size, -1], [1,1])
preprocessed_targets = tf.concat([left_side, right_side], 1)
return preprocessed_targets
#Encoder RNN
def encoder_rnn(rnn_inputs, rnn_size, num_layers, keep_prob, sequence_lenght):
lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)
lstm_dropout = tf.contrib.rnn.DropoutWrapper(lstm, input_keep_prob = keep_prob)
encoder_cell = tf.contrib.rnn.MultiRNNCell([lstm_dropout] * num_layers)
global encoder_output, encoder_state
encoder_output, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw = encoder_cell,
cell_bw = encoder_cell,
sequence_length = sequence_length,
inputs = rnn_inputs,
dtype = tf.float32)
return encoder_state
#Decoding training set
def decode_training_set(encoder_state, decoder_cell, decoder_embedded_input, sequence_lenght, decoding_scope, output_function, keep_prob, batch_size):
#attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units = decoder_cell.output_size, memory = encoder_output, normalize=False)
我能做些什么?
解决方案
推荐阅读
- excel - Excel VBA - UDF 返回 0 或空或 #value
- wordpress - WordPress - 主机通过 FTP 或 sFTP 将目录上传到远程服务器
- javascript - 如何在 Vue.js 中使用图像作为按钮?
- php - 在第 18 行调用 public_html/wp-content/themes/g5-beyot/header.php 中未定义的函数 g5plus_get_option()
- echarts - 鼠标离开图表区域时如何停止画笔拖动
- android - 当应用程序暂停需要建议时,android workmanager 不工作
- python-3.x - python文件中的2to3 ParseError
- scala - 如何为每个测试场景加载配置?
- python - Python有效地从URL下载图像
- python - 在一台 PC 上但在另一台 PC 上找不到包中的子模块