tensorflow - keras implementation of a parallel convolution layer
问题描述
learning keras and cnn in general, so tried to implement a network i found in a paper, in it there is a parallel convolution layer of 3 convs where each conv apply a different filter on the input, here how i tried to solve it:
inp = Input(shape=(32,32,192))
conv2d_1 = Conv2D(
filters = 32,
kernel_size = (1, 1),
strides =(1, 1),
activation = 'relu')(inp)
conv2d_2 = Conv2D(
filters = 64,
kernel_size = (3, 3),
strides =(1, 1),
activation = 'relu')(inp)
conv2d_3 = Conv2D(
filters = 128,
kernel_size = (5, 5),
strides =(1, 1),
activation = 'relu')(inp)
out = Concatenate([conv2d_1, conv2d_2, conv2d_3])
model.add(Model(inp, out))
-this gives me the following err : A Concatenate layer requires inputs with matching shapes except for the concat axis....etc
.
- i tried solving it by adding the arg
input_shape = inp
in every Conv2D function, now it gives meCannot iterate over a tensor with unknown first dimension.
ps : the paper writers implemented this network with caffe, the input to this layer is (32,32,192) and the output after the merge is (32,32,224).
解决方案
除非您添加填充以匹配数组形状,Concatenate
否则将无法匹配它们。尝试运行这个
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Concatenate
inp = Input(shape=(32,32,192))
conv2d_1 = Conv2D(
filters = 32,
kernel_size = (1, 1),
strides =(1, 1),
padding = 'SAME',
activation = 'relu')(inp)
conv2d_2 = Conv2D(
filters = 64,
kernel_size = (3, 3),
strides =(1, 1),
padding = 'SAME',
activation = 'relu')(inp)
conv2d_3 = Conv2D(
filters = 128,
kernel_size = (5, 5),
strides =(1, 1),
padding = 'SAME',
activation = 'relu')(inp)
out = Concatenate()([conv2d_1, conv2d_2, conv2d_3])
model = tf.keras.models.Model(inputs=inp, outputs=out)
model.summary()
推荐阅读
- hangouts-chat - 点击 Google Chat 上的 URL → 我要启动 Firefox
- python-3.x - Kivy + 多处理引发 TypeError
- f# - 如何根据内容按顺序对消息进行分组
- javascript - 如何在jQuery提交表单之前检查文件的内容是否已被修改?
- powerquery - Power Query:基于值的重复行
- xpath - IMPORTXHTML 提供无法从表格中的 nasdaq 获取 url 错误
- wpf - 如何在 DataGridTextColumn 的文本单击上执行命令
- go - 从 gocron 任务返回输出数据
- python - urllib.error.HTTPError:HTTP 错误 404:未使用 Google Matrix API 找到
- php - 在托管在亚马逊服务器上的网站中使用 Google Cloud Storage 遇到错误 500