首页 > 解决方案 > 从 tf.nn.dynamic_rnn 获取非填充条目的最后输出

问题描述

我想在多对一模式下使用 RNN(最后只有一个输出)。在 TensorFlow 中,可以使用:

lstm_cell = tf.nn.rnn_cell.LSTMCell(lstm_num_units)

output, _  = tf.nn.dynamic_rnn(lstm_cell, embed, dtype=tf.float32)

其中输出包含所有时间步的输出[0, max_time-1],并且max_time是批处理中最长输入的长度。

现在,我想获得批处理中每个输入的最后一个输出。让我更清楚。我在网上看到的所有实现都output[:,-1]用作最后的输出。但是,对于已填充的输入,这意味着输出来自填充的输入。

因此,问题:

  1. 使用 just 是多么合理output[:,-1]

  2. 是否有一种简单的方法可以为 TensorFlow 中的非填充值选择最后一个条目,通常,对于批处理中的每个输入,该条目将在不同的时间步长。不知何故,我发现使用 TensorFlow 张量进行必要的操作有点困难,即使我拥有所有输入序列的原始长度。

标签: pythontensorflow

解决方案


修改您的代码:

_, state  = tf.nn.dynamic_rnn(lstm_cell, embed, dtype=tf.float32, sequence_length=some_placeholder)

last_output = state.h

如果您希望序列长度发生变化,请不要忘记sequence_length在调用中添加参数。dynamic_rnn

或者,您可以使用tf.gather_nd.


推荐阅读