首页 > 解决方案 > 无法理解 tf.nn.raw_rnn

问题描述

在我们的官方文档中,tf.nn.raw_rnn我们将发出结构作为第一次运行loop_fn时的第三个输出。loop_fn

稍后,emit_structure 用于复制tf.zeros_like(emit_structure)到由 完成的小批量条目emit = tf.where(finished, tf.zeros_like(emit_structure), emit)

我对谷歌的缺乏理解或糟糕的文档是:emit structure is Noneso tf.where(finished, tf.zeros_like(emit_structure), emit)will throw a ValueError as do tf.zeros_like(None)so. 有人可以填写我在这里缺少的内容吗?

标签: python-3.xtensorflowrecurrent-neural-networkrnntensorflow-slim

解决方案


是的,这个地方的医生很混乱。如果您查看 的内部结构tf.nn.raw_rnn,则其中的关键术语是“伪代码”,因此文档中的示例不准确。

确切的源代码如下所示(可能因您的 tensorflow 版本而异):

if emit_structure is not None:
  flat_emit_structure = nest.flatten(emit_structure)
  flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                    array_ops.shape(emit) for emit in flat_emit_structure]
  flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
  emit_structure = cell.output_size
  flat_emit_size = nest.flatten(emit_structure)
  flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

所以它处理的情况是什么时候emit_structure is None简单地取值cell.output_size。这就是为什么没有什么东西真的会破裂。


推荐阅读