python - tf.placeholder_with_default 反向传播
问题描述
我正在尝试在 Tensorflow 中训练自动编码器。然而,这个自动编码器只是我架构的一部分。我想要以下工作流程:
输入图像-->上游处理的数据-->输入自动编码器-->使用编码图像输出及其梯度。
换句话说,我想要整个编码操作的梯度,包括上游数据处理,wrt 输入图像数据。
我也想自己训练我的自动编码器。因此,我认为我可以将输入图像制作成 type 的自动编码器tf.placeholder_with_default
。我的想法是我可以将上游数据处理直接连接到自动编码器输入作为默认值,但也可以允许用户传入训练数据进行训练。
因此,我这样构造自动编码器的输入:
input_x = tf.Variable(tf.zeros(dtype=tf.float32, shape=[1, 60, 200, 3])) #Will be fed in from upstream, for now, zeros is just for testing
self.x = tf.placeholder_with_default(input_x, shape=[None, 60, 200, 3], name='camera') #images are 200 x 60 with 3 channels; x is the input to the autoencoder
我的自动编码器涉及多次调用tf.nn.conv2d
. 不幸的是,当我尝试使用此设置进行训练时,出现以下错误:
InvalidArgumentError (see above for traceback): Conv2DSlowBackpropInput: input and out_backprop must have the same batch sizeinput batch: 1outbackprop batch: 32 batch_dim: 0
当我将代码更改为:
self.x = tf.placeholder(tf.float32, shape=[None, 60, 200, 3], name='camera')
我没有问题。
我是否尝试tf.placeholder_with_default
正确使用?什么可能导致错误?(我可以提供更多代码,但如果可能的话,我不想发布我的整个 AE)。
解决方案
尝试将您的 input_x 声明替换为:
input_x =np.zeros([1, 60, 200, 3], dtype=np.float32)
推荐阅读
- java - Java 8 / Fernflower 反编译器:错误或功能
- html - 当表单有效或在控制器中引用表单时,无法让 ng-message 消失
- excel - Excel公式计算单个单元格中的多个可能组合
- php - 在 PHP 中避免空值时,我可以重新实现 isset 吗?
- c# - 通过表单发布的数据挂起应用程序
- travis-ci - 使用 Travis CI 发布到 NPMJS
- angular - 使用自定义指令值选择性地显示表单元素
- aws-lambda - 在 AWS Lambda 中增加重试次数
- python - 无法在 Windows 上使用 winsound 播放 .mp3 文件
- powershell - 处理列表时的 Powershell 错误处理