machine-learning - 变分自动编码器的采样问题
问题描述
我的 VAE 类如下所示:
class Encoder(nn.Module):
    """Convolutional encoder that outputs the latent mean and log-variance.

    Two stride-2 convolutions are followed by two parallel linear heads.
    Layer widths are taken from the module-level ``capacity`` and
    ``latent_dims`` settings at construction time.
    """

    def __init__(self):
        super().__init__()
        base = capacity
        flat_features = base * 2 * 7 * 7
        # Shape notes below assume 1 x 28 x 28 inputs (kernel 4, stride 2,
        # padding 1 halves each spatial dimension).
        self.conv1 = nn.Conv2d(1, base, kernel_size=4, stride=2, padding=1)  # out: c x 14 x 14
        self.conv2 = nn.Conv2d(base, base * 2, kernel_size=4, stride=2, padding=1)  # out: 2c x 7 x 7
        self.fc_mu = nn.Linear(flat_features, latent_dims)
        self.fc_logvar = nn.Linear(flat_features, latent_dims)

    def forward(self, x):
        """Return the tuple ``(mu, logvar)`` computed from the image batch ``x``."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        # Collapse (channels, height, width) into one feature vector per sample.
        flat = hidden.view(hidden.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to images.

    A linear layer expands each latent vector to a ``2c x 7 x 7`` feature
    map, which two stride-2 transposed convolutions upsample to a
    single-channel image with values in ``[0, 1]``. Layer widths are taken
    from the module-level ``capacity`` and ``latent_dims`` settings at
    construction time.
    """

    def __init__(self):
        super().__init__()
        c = capacity
        self.fc = nn.Linear(in_features=latent_dims, out_features=c * 2 * 7 * 7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c * 2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Decode a batch of latent vectors ``x`` into a batch of images.

        Args:
            x: tensor whose last dimension matches ``self.fc.in_features``.

        Returns:
            Tensor of sigmoid activations produced by the last layer.
        """
        x = self.fc(x)
        # Unflatten using the channel count of the layer that consumes the
        # feature map, so the reshape cannot drift from the constructed
        # network if the global ``capacity`` is changed after construction.
        x = x.view(x.size(0), self.conv2.in_channels, 7, 7)
        x = F.relu(self.conv2(x))
        # Last layer before output is sigmoid, since BCE is used as the
        # reconstruction loss.
        return torch.sigmoid(self.conv1(x))
class VariationalAutoencoder(nn.Module):
    """Variational autoencoder composed of an ``Encoder`` and a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, draw a latent code and decode it.

        Returns:
            The tuple ``(x_recon, latent_mu, latent_logvar)``.
        """
        latent_mu, latent_logvar = self.encoder(x)
        code = self.latent_sample(latent_mu, latent_logvar)
        return self.decoder(code), latent_mu, latent_logvar

    def latent_sample(self, mu, logvar):
        """Return a latent code for the Gaussian described by ``mu`` and ``logvar``.

        In training mode a random sample is drawn; otherwise ``mu`` itself is
        returned unchanged.
        """
        if not self.training:
            return mu
        # The reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I),
        # with sigma = exp(logvar / 2).
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return noise * sigma + mu
def vae_loss(recon_x, x, mu, logvar, beta=None):
    """Return the VAE training loss: summed BCE plus weighted KL divergence.

    Args:
        recon_x: reconstructed batch; treated as the probabilities of a
            multivariate Bernoulli distribution, so values must lie in [0, 1].
        x: target batch with the same number of elements as ``recon_x``.
        mu: latent means produced by the encoder.
        logvar: latent log-variances produced by the encoder.
        beta: weight of the KL term. When ``None`` (the default) the
            module-level ``variational_beta`` is used.

    Returns:
        Scalar tensor ``recon_loss + beta * kl_divergence``.
    """
    if beta is None:
        beta = variational_beta
    # recon_x is the probability of a multivariate Bernoulli distribution p.
    # -log(p(x)) is then the pixel-wise binary cross-entropy.
    # Averaging or not averaging the binary cross-entropy over all pixels here
    # is a subtle detail with big effect on training, since it changes the
    # weight we need to pick for the other loss term by several orders of
    # magnitude. Not averaging is the direct implementation of the negative
    # log likelihood, but averaging makes the weight of the other loss term
    # independent of the image resolution.
    #
    # Because the terms are summed, the tensors only need matching element
    # order; flattening completely removes the dependence on a 28 x 28 (784)
    # image size, and reshape also accepts non-contiguous inputs.
    recon_loss = F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kldivergence
我在 MNIST 数据集上训练了它。现在我想从中采样,也就是生成一个数组,把它送入解码器,然后查看输出是什么。
问题是我不太清楚 z 数组应该取什么值,以及它应该是什么形状。
这是采样的代码:
# Generate new images by decoding latent codes drawn from the prior.
# The KL term of the loss pulls the posterior towards a standard normal, so
# z is sampled from N(0, I); its shape is (number of samples, latent_dims)
# because the decoder's first layer is Linear(in_features=latent_dims).
n_row, n_col = 5, 5
z = torch.randn(n_row * n_col, latent_dims, device=device)
vae.eval()
# No gradients are needed for sampling.
with torch.no_grad():
    output = vae.decoder(z)
# NOTE(review): the decoder upsamples 7x7 -> 14x14 -> 28x28, so the 24, 24
# arguments look inconsistent with the image size — confirm what
# plot_gallery expects for these two parameters.
plot_gallery(output.cpu().numpy(), 24, 24, n_row=n_row, n_col=n_col)
解决方案
推荐阅读
- python-3.x - 如何在 python kivy 框架中创建音频控制器?
- python-3.x - 文件中的正则表达式字符串替换
- r - 在 devtools::dev_mode() 中使用不同的 R makevars 文件
- sql - 如何在 Azure 数据工厂管道中使用 Try and Catch 命令?
- database - 使用集群时的TDengine查询
- ios - 调度组等待超时
- mysql - Mysql 用 if case 或 Join 连接 3 个表?
- html - 为什么当我将 div 设置为响应式正方形时,div 的背景颜色没有显示?
- java - 二元运算符'<='java错误的错误操作数类型
- nginx - Nginx - nginx:[emerg] 意外“,”