python - 用于图像重建的自动编码器产生带有奇怪网格的灰色图像
问题描述
我正在尝试使用自动编码器制作 Deep Fake。我使用一个编码器和两个解码器:一个用于目标图像,另一个用于源图像(目标人脸是我想“粘贴”到源图像头部的人脸)。所以首先我尝试训练编码器和两个解码器来重建输入的人脸图像(形状为 300×300×3)。但是两个解码器输出的都不是彩色图像,而是灰色图像,因为每个像素的红色、绿色和蓝色值几乎相同。除此之外,输出图像上还有一个奇怪的 3x3 网格:
我使用的批量大小为 1,因为在这种情况下我不知道如何使用小批量(但那是另一个问题)。我还使用了残差连接,这提高了质量。最后一层使用 sigmoid 激活(这可能是错误的)。我的损失函数是二元交叉熵,优化器是 Adam。学习率为 0.001(我也试过 0.0001 和 0.00075)。
这是我的模型:
import matplotlib.pyplot as plt
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Autoencoder with one shared encoder and two decoders.

    The encoder is shared; ``label`` selects which decoder reconstructs the
    image: ``"0"`` for the target decoder, ``"1"`` for the source decoder.
    The pooled encoder feature maps are added to the matching decoder stages
    as skip connections, so the input's spatial size must be one that every
    upsample + transposed-convolution pair restores exactly (300x300 and
    120x120 both satisfy this).
    """

    def __init__(self):
        super().__init__()

        # Encoder: four conv -> batch-norm -> ReLU -> max-pool stages.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=(4, 4))
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(4, 4))
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3))
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=(4, 4))
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.maxpool3x3 = nn.MaxPool2d(3)
        self.maxpool2x2 = nn.MaxPool2d(2)

        # Target decoder (selected with label "0").
        self.targetDeconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(4, 4))
        self.targetBatchnorm1 = nn.BatchNorm2d(128)
        self.targetDeconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(3, 3))
        self.targetBatchnorm2 = nn.BatchNorm2d(64)
        self.targetDeconv3 = nn.ConvTranspose2d(64, 32, kernel_size=(4, 4))
        self.targetBatchnorm3 = nn.BatchNorm2d(32)
        self.targetDeconv4 = nn.ConvTranspose2d(32, 3, kernel_size=(4, 4))

        # Parameter-free upsampling layers, shared by both decoders.
        # They were previously assigned twice; one definition is enough.
        self.upsample3x3 = nn.Upsample(scale_factor=3)
        self.upsample2x2 = nn.Upsample(scale_factor=2)

        # Source decoder (selected with label "1").
        self.sourceDeconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(4, 4))
        self.sourceBatchnorm1 = nn.BatchNorm2d(128)
        self.sourceDeconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(3, 3))
        self.sourceBatchnorm2 = nn.BatchNorm2d(64)
        self.sourceDeconv3 = nn.ConvTranspose2d(64, 32, kernel_size=(4, 4))
        self.sourceBatchnorm3 = nn.BatchNorm2d(32)
        self.sourceDeconv4 = nn.ConvTranspose2d(32, 3, kernel_size=(4, 4))

    def _visualize_features(self, feature_maps, dim: tuple = (), title: str = ""):
        """Plot channels of the first sample in ``feature_maps`` on a grid.

        Args:
            feature_maps: tensor indexed as ``[sample][channel]``.
            dim: ``(rows, cols)`` of the subplot grid.
            title: figure title.

        Plotting is best-effort: any failure is reported as a warning and
        never interrupts the forward pass.
        """
        import warnings

        try:
            rows, cols = dim
            # Convert once instead of on every loop iteration.
            maps = feature_maps.detach().cpu().numpy()[0]
            # squeeze=False keeps ``axs`` two-dimensional even when rows or
            # cols is 1; without it a (3, 1) grid yields a 1-D array and
            # ``axs[i][j]`` fails.
            fig, axs = plt.subplots(rows, cols, squeeze=False)
            # Never index past the number of available channels.
            for idx in range(min(rows * cols, maps.shape[0])):
                axs[idx // cols][idx % cols].matshow(maps[idx])
            fig.suptitle(title)
            plt.show()
        except Exception as exc:
            warnings.warn(f"feature visualization skipped: {exc!r}")

    def _log_stage(self, x, dim, visualize):
        """Print the shape of ``x`` and plot it when ``visualize`` is set."""
        if visualize:
            print(x.shape)
            self._visualize_features(x, dim=dim)

    def _encode(self, x, visualize):
        """Run the encoder; return ``(latent, skips)`` with skips shallow-to-deep."""
        stages = (
            (self.conv1, self.batchnorm1, self.maxpool3x3),
            (self.conv2, self.batchnorm2, self.maxpool3x3),
            (self.conv3, self.batchnorm3, self.maxpool2x2),
            (self.conv4, self.batchnorm4, self.maxpool2x2),
        )
        pooled = []
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))
            self._log_stage(x, (4, 4), visualize)
            pooled.append(x)
        # The deepest map is the latent code; the earlier ones are skips.
        return pooled[-1], pooled[:-1]

    def _decode(self, x, skips, deconvs, norms, visualize):
        """Run one decoder given its four deconv layers and three batch norms."""
        upsamples = (self.upsample2x2, self.upsample2x2, self.upsample3x3)
        # Skips are consumed deepest-first, mirroring the encoder.
        for up, deconv, norm, skip in zip(
            upsamples, deconvs[:3], norms, reversed(skips)
        ):
            x = deconv(up(x)) + skip
            x = F.relu(norm(x))
            self._log_stage(x, (4, 4), visualize)
        # Final stage: no skip connection or batch norm, sigmoid output.
        x = torch.sigmoid(deconvs[3](self.upsample3x3(x)))
        self._log_stage(x, (3, 1), visualize)
        return x

    def forward(self, x, label: str = "0", visualize: bool = False):
        """Encode ``x`` and reconstruct it with the decoder chosen by ``label``.

        Args:
            x: input batch with 3 channels.
            label: ``"0"`` for the target decoder, ``"1"`` for the source decoder.
            visualize: print shapes and plot feature maps after every stage.

        Returns:
            The sigmoid-activated 3-channel reconstruction.

        Raises:
            ValueError: if ``label`` is neither ``"0"`` nor ``"1"`` (the
                previous behaviour was to silently return ``None``).
        """
        if label == "0":
            deconvs = (
                self.targetDeconv1,
                self.targetDeconv2,
                self.targetDeconv3,
                self.targetDeconv4,
            )
            norms = (
                self.targetBatchnorm1,
                self.targetBatchnorm2,
                self.targetBatchnorm3,
            )
        elif label == "1":
            deconvs = (
                self.sourceDeconv1,
                self.sourceDeconv2,
                self.sourceDeconv3,
                self.sourceDeconv4,
            )
            norms = (
                self.sourceBatchnorm1,
                self.sourceBatchnorm2,
                self.sourceBatchnorm3,
            )
        else:
            raise ValueError(f"label must be '0' or '1', got {label!r}")

        latent, skips = self._encode(x, visualize)
        return self._decode(latent, skips, deconvs, norms, visualize)
所以我的问题是:为什么输出图像是灰色的?输出图像上的那个网格是什么?这两个问题有关系吗?我检查过这是否与 RGB/BGR 通道顺序有关,但看起来不是。希望有人能帮我解决这个问题,在此先感谢:)
解决方案
推荐阅读
- python - 如何找到重复患者并添加新列
- list - SwiftUI 如何在其中创建带有自定义 UIViews 的列表
- virtual-machine - 使用 Terraform for Azure 创建多个虚拟机的问题
- r - 使用 nlstools 包或 glm 在 R 中进行逻辑回归?
- c# - 为什么 EF 核心中的 ExecuteSqlRaw() 方法忽略了一些数据库关键字
- android - 如何制作工具栏,使其始终位于其下的所有视图和布局之上?
- javascript - 如何标记分组条形图中的每个条形图?
- azure-iot-hub - 如何使用 azure python sdk 版本 2 将经过 X.509 身份验证的下游设备连接到启用了 azure edge 的网关
- angular - IIS 重定向到新 URL
- python - 通过脚本 Python 在合成部分 Blender 中创建颜色渐变节点