deep-learning - 使用 Pytorch 进行批量归一化逆计算
问题描述
我已经实现了一个 BatchNorm 类来计算 BatchNormalization 及其逆,当我用 1 批次的张量对其进行测试时,它可以正常工作,但是当我针对多批次张量对其进行测试时,它就不能正常工作。代码:
class BatchNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.zeros(1, dim), requires_grad=True)
self.beta = nn.Parameter(torch.zeros(1, dim), requires_grad=True)
self.batch_mean = None
self.batch_var = None
def forward(self, x, reverse=False):
B, C, W, H = x.shape
if(reverse == True):
return self.reverse(x)
if self.training:
if(B>1):
m = x.mean(dim=0)
v = x.var(dim=0) + self.eps # torch.mean((x - m) ** 2, axis=0) + self.eps
else:
m = torch.zeros(C, W, H)
v = torch.zeros(C, W, H) + self.eps
self.batch_mean = None
else:
if self.batch_mean is None:
self.set_batch_stats_func(x)
m = self.batch_mean.clone()
v = self.batch_var.clone()
B, C, W, H = x.shape
gamma = self.gamma.unsqueeze(2).unsqueeze(3)
gamma = torch.repeat_interleave(gamma, H, dim=2)
gamma = torch.repeat_interleave(gamma, W, dim=3)
beta = self.beta.unsqueeze(2).unsqueeze(3)
beta = torch.repeat_interleave(beta, H, dim=2)
beta = torch.repeat_interleave(beta, W, dim=3)
#print('x_hat:', x_hat)
x_hat = (x - m) / torch.sqrt(v)
x_hat = x_hat * torch.exp(gamma) + beta
x_2 = (x_hat - beta) * torch.exp(-gamma) * torch.sqrt(v) + m
#print('forward: dist:', torch.dist(x, x_2))
#print('forward: x:', x[0,0,:3,:3])
#print('forward: x_2:', x_2[0,0,:3,:3])
#print('forward: x_hat:', x_hat[0,0,:3,:3])
log_det = torch.sum(gamma - 0.5 * torch.log(v))
return x_hat, log_det
def reverse(self, x):
B, C, W, H = x.shape
if self.training:
if(B>1):
m = x.mean(dim=0)
v = x.var(dim=0) + self.eps # torch.mean((x - m) ** 2, axis=0) + self.eps
else:
m = torch.zeros(C, W, H)
v = torch.zeros(C, W, H) + self.eps
self.batch_mean = None
else:
if self.batch_mean is None:
self.set_batch_stats_func(x)
m = self.batch_mean
v = self.batch_var
B, C, W, H = x.shape
gamma = self.gamma.unsqueeze(2).unsqueeze(3)
gamma = torch.repeat_interleave(gamma, H, dim=2)
gamma = torch.repeat_interleave(gamma, W, dim=3)
beta = self.beta.unsqueeze(2).unsqueeze(3)
beta = torch.repeat_interleave(beta, H, dim=2)
beta = torch.repeat_interleave(beta, W, dim=3)
x_hat = (x - beta) * torch.exp(-gamma) * torch.sqrt(v) + m
#print('reverse: dist:', torch.dist(x, x_hat))
#print('reverse: x:', x[0,0,:3,:3])
#print('reverse: x_hat:', x_hat[0,0,:3,:3])
log_det = torch.sum(-gamma + 0.5 * torch.log(v))
return x_hat, log_det
def set_batch_stats_func(self, x):
print("setting batch stats for validation")
self.batch_mean = x.mean(dim=0)
self.batch_var = x.var(dim=0) + self.eps
对一批张量进行测试:
x = torch.rand(1,10,100,100)
Batch = BatchNorm(10)
x1,_ = Batch(x, False)
x2,_ = Batch(x1, True)
torch.dist(x,x2)
并且输出大约为零,这意味着前向和后向路径都正常工作,但对于多批次张量:
x = torch.rand(3,10,100,100)
Batch = BatchNorm(10)
x1,_ = Batch(x, False)
x2,_ = Batch(x1, True)
torch.dist(x,x2)
在这种情况下,结果(输入和重构输入之间的差异)是一个巨大的数字。但是,它必须接近于零。
解决方案
推荐阅读
- angular - Angular(不是AngularJs)中angular-tiny-calendar包的用途是什么?
- haskell - 有没有办法在 Haskell 中对长度进行模式匹配?
- linux - 无法编译 android 内核:在 built-in.o 中停止
- variables - jq 的变量绑定是否通过函数起作用?
- templates - 使用 VueJs 计算“脱离 DOM”的模板绑定
- java - 如何仅从具有重复项的列表中删除一项?
- keyerror - Python KeyError:我想不通(python 新手)
- java - Spring 5 AOP:不执行建议
- python - 使用 TensorFlow 预测新数据
- r - 边的子集