tensorflow - 使用 K.tile() 复制张量
问题描述
我有张量(None, 196)
,在重塑之后,它变成了(None, 14, 14)
。现在,我想将它复制到通道轴,这样形状应该是(None, 14, 14, 512)
. 最后,我想复制到时间步长轴,所以它变成(None, 10, 14, 14, 512)
. 我使用此代码段完成这些步骤:
def replicate(tensor, input_target):
batch_size = K.shape(tensor)[0]
nf, h, w, c = input_target
x = K.reshape(tensor, [batch_size, 1, h, w, 1])
# Replicate to channel dimension
x = K.tile(x, [batch_size, 1, 1, 1, c])
# Replicate to timesteps dimension
x = K.tile(x, [batch_size, nf, 1, 1, 1])
return x
x = ...
x = Lambda(replicate, arguments={'input_target':input_shape})(x)
another_x = Input(shape=input_shape) # shape (10, 14, 14, 512)
x = layers.multiply([x, another_x])
x = ...
我绘制模型,输出形状就像我想要的那样。但是,问题出现在模型训练中。我将批量大小设置为 2。这是错误消息:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [8,10,14,14,512] vs. [2,10,14,14,512]
[[{{node multiply_1/mul}} = Mul[T=DT_FLOAT, _class=["loc:@training/Adam/gradients/multiply_1/mul_grad/Sum"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](Lambda_2/Tile_1, _arg_another_x_0_0/_189)]]
[[{{node metrics/top_k_categorical_accuracy/Mean_1/_265}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_6346_metrics/top_k_categorical_accuracy/Mean_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
看起来,K.tile()
将批量大小从 2 增加到 8。当我将批量大小设置为 10 时,它变为 1000。
所以,我的问题是如何达到我想要的结果?使用方法好不好tile()
?或者,我应该使用repeat_elements()
吗?谢谢!
我正在使用 Tensorflow 1.12.0 和 Keras 2.2.4。
解决方案
根据经验,尽量避免将批量大小引入Lambda
层中发生的转换。
当您使用tile
操作时,您只设置需要更改的维度(例如,您batch_size
在平铺操作中有值是错误的)。我也在使用tf.tile
而不是K.tile
(TF 1.12 在 Keras 后端似乎没有磁贴)。
def replicate(tensor, input_target):
_, nf, h, w, c = input_target
x = K.reshape(tensor, [-1, 1, h, w, 1])
# Replicate to channel dimension
# You can combine below lines to tf.tile(x, [1, nf, 1, 1, c]) as well
x = tf.tile(x, [1, 1, 1, 1, c])
# Replicate to timesteps dimension
x = tf.tile(x, [1, nf, 1, 1, 1])
return x
简单的例子
input_shape= [None, 10, 14, 14, 512]
x = Input(shape=(196,))
x = Lambda(replicate, arguments={'input_target':input_shape})(x)
print(x.shape)
这使
>>> (?, 10, 14, 14, 512)
推荐阅读
- scala - 将 null 添加到 int 列
- python - 如何使用百分比值绘制进度饼图?
- javascript - 在javascript对象中,如果标签存在于另一个字段中,如何对一个字段的值求和?
- r - R markdown:在乳胶表中使用 R 代码
- node.js - 将数据导入 Sequelize + Heroku Postgres 时会导致丢失行的原因是什么?
- python - Python setuptools:打包根目录(每个包不需要子目录)
- awk - yum list 获取最后一个可用的包
- java - 如何访问android 11中的内部存储?
- amazon-web-services - 使用远程桌面时网页上出现错误 403
- python - 如何实现装饰器来强制 Python 类型提示?