首页 > 解决方案 > 使用 tf.while_loop (TensorFlow) 从图中累积输出

问题描述

长话短说,我有一个堆叠在 CNN 之上的 RNN。CNN 是单独创建和训练的。为了澄清事情,让我们假设 CNN 以 [BATCH SIZE, H, W, C] 占位符的形式接受输入(H = 高度,W = 宽度,C = 通道数)。

现在,当堆叠在 RNN 之上时,组合网络的整体输入将具有以下形状:[BATCH SIZE, TIME SEQUENCE, H, W, C],即 minibatch 中的每个样本由 TIME_SEQUENCE 多个图像组成。此外,时间序列的长度是可变的。有一个名为sequence_lengths[BATCH SIZE] 形状的单独占位符,其中包含与小批量中每个样本的长度相对应的标量值。TIME SEQUENCE 的值对应于最大可能的时间序列长度,对于长度较小的样本,剩余的值用零填充。

我想做的事

我想以 [BATCH SIZE, TIME SEQUENCE, 1] 形状的张量累积 CNN 的输出(最后一个维度只包含 CNN 为每个批次元素的每个时间样本输出的最终分数),以便我可以转发将整个信息块传递给 RNN,并堆叠在 CNN 之上。棘手的是,我还希望能够将错误从 RNN 反向传播到 CNN(CNN 已经预训练,但我想稍微微调一下权重),所以我必须留在图表内,即我不能对session.run().

在里面my_cnn_model.process_input,我只是通过香草 CNN 传递输入。在其中创建的所有变量都带有tf.AUTO_REUSE,因此应该确保 while 循环为所有循环迭代重用相同的权重。

确切的问题

image_output_sequence是一个变量,但不知何故,当tf.while_loop调用该body方法时,它会变成一个无法对其进行赋值的张量类型对象。我收到错误消息:Sliced assignment is only supported for variables

即使我使用另一种格式(例如使用 BATCH SIZE 张量的元组,每个元组的维度为 [TIME SEQUENCE、H、W、C]),这个问题仍然存在。

我也愿意对代码进行彻底的重新设计,只要它能很好地完成工作。

标签: pythontensorflowdeep-learningrecurrent-neural-network

解决方案


解决方案是使用TensorArray专门为解决此类问题而设计的类型对象。以下行:

image_output_sequence = tf.Variable(tf.zeros([batch_size, max_sequence_length, 1], tf.float32))

替换为:

image_output_sequence = tf.TensorArray(size=batch_size, dtype=tf.float32, element_shape=[max_sequence_length, 1], infer_shape=True)

TensorArray实际上并不需要每个元素都有固定的形状,但就我而言,它是固定的,所以最好强制执行。

然后在body函数内部,替换这个:

ios[lc].assign(padded_cnn_features)

和:

ios = ios.write(lc, padded_cnn_output)

然后在tf.while_loop语句之后,TensorArray可以堆叠形成一个正则Tensor进行进一步处理:

stacked_tensor = result.stack()

推荐阅读