pytorch - pytorch 模型的多个分支
问题描述
嗨,我正在尝试使用 pytorch 制作这个模型。
每个输入由 20 张大小为 28 X 28 的图像组成,即图像中的 C1 ~ Cp。每个图像都进入相同结构的 CNN,但它们的输出最终被连接起来。
我目前正在努力为每个各自的 CNN 模型提供多个输入。第一个带有三个卷积层的模型中的每个模型看起来都是这样的代码,但我不太确定如何将 20 个不同的输入放入相同结构的不同模型中以最终连接起来。
self.features = nn.Sequential(
nn.Conv2d(1,10, kernel_size = 3, padding = 1),
nn.ReLU(),
nn.Conv2d(10, 14, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(14, 18, kernel_size=3, padding=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(28*28*18, 256)
)
我已经尝试提供一个输入列表作为转发功能的输入,但它最终出现错误并且不会通过。如果有任何不清楚的地方,我将非常乐意进一步解释。
解决方案
假设每条路径都有自己的权重,可能这可以通过分组卷积来完成,尽管预融合Linear
可能会造成一些麻烦。
P = 20
self.features = nn.Sequential(
nn.Conv2d(1*P,10*P, kernel_size = 3, padding = 1, groups = P ),
nn.ReLU(),
nn.Conv2d(10*P, 14*P, kernel_size=3, padding=1, groups = P),
nn.ReLU(),
nn.Conv2d(14*P, 18*P, kernel_size=3, padding=1, groups = P),
nn.ReLU(),
nn.Conv2d(18*P, 256*P, kernel_size=28, groups = P), # not shure about this one
nn.Flatten(),
nn.Linear(256*P, 1024 )
)
推荐阅读
- django - Django Rest Framework 传递给 ModelSerializer 的参数是什么意思?
- fonts - WebStorm 2018 项目窗格:如何更改字体颜色?
- firewall - 具有内部和外部网络的透明代理 Squid
- ios - 有没有更简单的方法来替换在 xcode 中设置到 xib 中的所有字体?
- postman - 邮递员休息客户端每次响应时都会给出 404 错误
- php - 如果我们在 multiselect 上写任何东西,如何将文本插入数据库表?
- r - 使用包 DLM 和 FKF 的状态空间
- java - 如何在“X”号之后显示插页式广告。应用中的点击次数?
- java - 用没有多个发送键的文本填充所有文本区域(循环)?
- css - 动态圆圈导航菜单