tensorflow - Keras / TensorFlow:将常数层连接到卷积
问题描述
对于每个卷积激活图,我想连接一层常量——更具体地说,我想连接一个网格网格。(这是为了复制 Uber 的一篇论文。)
例如,假设我有一个激活图(?, 256, 256, 32)
;然后我想连接一个常量层 shape (?, 256, 256, 1)
。
这就是我这样做的方式:
from keras import layers
import tensorflow as tf
import numpy as np
input_layer = layers.Input((256, 256, 3))
conv = layers.Conv2D(32, 3, padding='same')(input_layer)
print('conv:', conv.shape)
xx, yy = np.mgrid[:256, :256] # [(256, 256), (256, 256)]
xx = tf.constant(xx, np.float32)
yy = tf.constant(yy, np.float32)
xx = tf.reshape(xx, (-1, 256, 256, -1))
yy = tf.reshape(yy, (-1, 256, 256, -1))
print('xx:', xx.shape, 'yy:', yy.shape)
concat = layers.Concatenate()([conv, xx, yy])
print('concat:', concat.shape)
conv2 = layers.Conv2D(32, 3, padding='same')(concat)
print('conv2:', conv2.shape)
但我得到了错误:
conv: (?, 256, 256, 32)
xx: (?, 256, 256, ?) yy: (?, 256, 256, ?)
concat: (?, 256, 256, ?)
Traceback (most recent call last):
File "temp.py", line 21, in <module>
conv2 = layers.Conv2D(32, 3, padding='same')(concat)
[...]
raise ValueError('The channel dimension of the inputs '
ValueError: The channel dimension of the inputs should be defined. Found `None`.
问题是我的常量层是(?, 256, 256, ?)
,而不是(?, 256, 256, 1)
,然后是下一个卷积层错误输出。
我尝试了其他事情但没有成功。
PS:我试图实现的论文已经在这里实现。
解决方案
问题是tf.reshape不能推断出多于一维的形状(即使用-1
多于一维会导致未定义的维度?
)。由于您想要 和 的形状xx
,您yy
可以(?, 256, 256, 1)
按如下方式重塑这些张量:
xx = tf.reshape(xx, (-1, 256, 256, 1))
yy = tf.reshape(yy, (-1, 256, 256, 1))
生成的形状将是(1, 256, 256, 1)
. 现在,conv
is(?, 256, 256, 32)
和keras.layers.Concatenate要求所有输入的形状都匹配,除了 concat 轴。然后,您可以使用tf.tile沿第一个维度重复张量xx
以yy
匹配批量大小:
xx = tf.tile(xx, [tf.shape(conv)[0], 1, 1, 1])
yy = tf.tile(yy, [tf.shape(conv)[0], 1, 1, 1])
xx
和的形状yy
现在是(?, 256, 256, 1)
,并且张量可以连接起来,因为它们的第一个维度与批量大小匹配。
推荐阅读
- angular - Angular 5 - 多个 Mat 选项卡使用子组件的相同实例
- cassandra - CassandraPageRequest 的序列化
- spring - 春季启动推特问题
- android - Recyclerview在android中滚动其他布局
- c# - log4net 1 个附加程序 2 个文件
- solr - 过滤子文档并检索父+子文档
- java - 休眠“自动模式更新”会丢弃它不拥有的任何东西吗?
- chef-infra - 使用 Chef 更新系统用户获取配方编译错误
- mysql - MySQL - 搜索 JSON 类型行以查找打开的商店
- ruby-on-rails - 嵌套验证失败时表单项消失