python - 沙漏网络和自动编码器有什么区别?
问题描述
我试图实现沙漏网络,我发现很难理解和可视化下面的代码。谁能显示跳过连接链接如何在下面的网络中连接?
import torch
import torch.nn as nn
class Hourglass(nn.Module):
def __init__(self):
super(Hourglass, self).__init__()
self.leaky_relu = nn.LeakyReLU()
self.d_conv_1 = nn.Conv2d(2, 8, 5, stride=2, padding=2)
self.d_bn_1 = nn.BatchNorm2d(8)
self.d_conv_2 = nn.Conv2d(8, 16, 5, stride=2, padding=2)
self.d_bn_2 = nn.BatchNorm2d(16)
self.d_conv_3 = nn.Conv2d(16, 32, 5, stride=2, padding=2)
self.d_bn_3 = nn.BatchNorm2d(32)
self.s_conv_3 = nn.Conv2d(32, 4, 5, stride=1, padding=2)
self.d_conv_4 = nn.Conv2d(32, 64, 5, stride=2, padding=2)
self.d_bn_4 = nn.BatchNorm2d(64)
self.s_conv_4 = nn.Conv2d(64, 4, 5, stride=1, padding=2)
self.d_conv_5 = nn.Conv2d(64, 128, 5, stride=2, padding=2)
self.d_bn_5 = nn.BatchNorm2d(128)
self.s_conv_5 = nn.Conv2d(128, 4, 5, stride=1, padding=2)
self.d_conv_6 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
self.d_bn_6 = nn.BatchNorm2d(256)
self.u_deconv_5 = nn.ConvTranspose2d(256, 124, 4, stride=2, padding=1)
self.u_bn_5 = nn.BatchNorm2d(128)
self.u_deconv_4 = nn.ConvTranspose2d(128, 60, 4, stride=2, padding=1)
self.u_bn_4 = nn.BatchNorm2d(64)
self.u_deconv_3 = nn.ConvTranspose2d(64, 28, 4, stride=2, padding=1)
self.u_bn_3 = nn.BatchNorm2d(32)
self.u_deconv_2 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)
self.u_bn_2 = nn.BatchNorm2d(16)
self.u_deconv_1 = nn.ConvTranspose2d(16, 8, 4, stride=2, padding=1)
self.u_bn_1 = nn.BatchNorm2d(8)
self.out_deconv = nn.ConvTranspose2d(8, 3, 4, stride=2, padding=1)
self.out_bn = nn.BatchNorm2d(3)
def forward(self, noise):
down_1 = self.d_conv_1(noise)
down_1 = self.d_bn_1(down_1)
down_1 = self.leaky_relu(down_1)
down_2 = self.d_conv_2(down_1)
down_2 = self.d_bn_2(down_2)
down_2 = self.leaky_relu(down_2)
down_3 = self.d_conv_3(down_2)
down_3 = self.d_bn_3(down_3)
down_3 = self.leaky_relu(down_3)
skip_3 = self.s_conv_3(down_3)
down_4 = self.d_conv_4(down_3)
down_4 = self.d_bn_4(down_4)
down_4 = self.leaky_relu(down_4)
skip_4 = self.s_conv_4(down_4)
down_5 = self.d_conv_5(down_4)
down_5 = self.d_bn_5(down_5)
down_5 = self.leaky_relu(down_5)
skip_5 = self.s_conv_5(down_5)
down_6 = self.d_conv_6(down_5)
down_6 = self.d_bn_6(down_6)
down_6 = self.leaky_relu(down_6)
up_5 = self.u_deconv_5(down_6)
up_5 = torch.cat([up_5, skip_5], 1)
up_5 = self.u_bn_5(up_5)
up_5 = self.leaky_relu(up_5)
up_4 = self.u_deconv_4(up_5)
up_4 = torch.cat([up_4, skip_4], 1)
up_4 = self.u_bn_4(up_4)
up_4 = self.leaky_relu(up_4)
up_3 = self.u_deconv_3(up_4)
up_3 = torch.cat([up_3, skip_3], 1)
up_3 = self.u_bn_3(up_3)
up_3 = self.leaky_relu(up_3)
up_2 = self.u_deconv_2(up_3)
up_2 = self.u_bn_2(up_2)
up_2 = self.leaky_relu(up_2)
up_1 = self.u_deconv_1(up_2)
up_1 = self.u_bn_1(up_1)
up_1 = self.leaky_relu(up_1)
out = self.out_deconv(up_1)
out = self.out_bn(out)
out = nn.Sigmoid()(out)
return out
自动编码器和沙漏网络有什么区别,对我来说两者看起来几乎一样,请对此有所了解。如果两者都是相同的编码器和解码器架构,那么为什么沙漏会这样命名?
解决方案
推荐阅读
- http - Flutter HttpException:接收数据时连接关闭
- git - Bitbucket Pipelines 随机失败 - “无法克隆存储库”错误
- gnuplot - 使用 gnuplot 从离散数据中绘制平滑球体
- reactjs - 如何在反应中将页面呈现到基本网址
- android - 如何在flutter中调用runApp方法中的两个类
- angular - Angular 2+,仅在条件适用但显示子元素时显示容器元素
- forms - Symfony 中的多种形式。显示有效,但多次持续存在
- vue.js - 如何在不停止执行测试的情况下等待一些文本?
- amazon-web-services - AWS Workspaces 和启用 RDP 的 EC2 实例有什么区别?
- c++ - 使用cmake错误链接boost日志