Using Pytorch to compute the inverse of batch normalization

Problem description

I have implemented a BatchNorm class that computes batch normalization and its inverse. It works correctly when I test it with a batch-size-1 tensor, but it fails when I test it with a multi-batch tensor. The code:

import torch
import torch.nn as nn

class BatchNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.zeros(1, dim), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(1, dim), requires_grad=True)
        self.batch_mean = None
        self.batch_var = None

    def forward(self, x, reverse=False):
        B, C, W, H = x.shape
        if(reverse == True):
            return self.reverse(x)
        if self.training:
            if(B>1):
                m = x.mean(dim=0)
                v = x.var(dim=0) + self.eps  # torch.mean((x - m) ** 2, axis=0) + self.eps
            else:
                m = torch.zeros(C, W, H)
                v = torch.zeros(C, W, H) + self.eps
            self.batch_mean = None
        else:
            if self.batch_mean is None:
                self.set_batch_stats_func(x)
            m = self.batch_mean.clone()
            v = self.batch_var.clone()
        B, C, W, H = x.shape
        gamma = self.gamma.unsqueeze(2).unsqueeze(3)
        gamma = torch.repeat_interleave(gamma, H, dim=2)
        gamma = torch.repeat_interleave(gamma, W, dim=3)
        beta = self.beta.unsqueeze(2).unsqueeze(3)
        beta = torch.repeat_interleave(beta, H, dim=2)
        beta = torch.repeat_interleave(beta, W, dim=3)
        #print('x_hat:', x_hat)
        x_hat = (x - m) / torch.sqrt(v)
        x_hat = x_hat * torch.exp(gamma) + beta
        
        x_2 = (x_hat - beta) * torch.exp(-gamma) * torch.sqrt(v) + m
        #print('forward: dist:', torch.dist(x, x_2))
        #print('forward: x:', x[0,0,:3,:3])
        #print('forward: x_2:', x_2[0,0,:3,:3])
        #print('forward: x_hat:', x_hat[0,0,:3,:3])
        
        log_det = torch.sum(gamma - 0.5 * torch.log(v))
        return x_hat, log_det

    def reverse(self, x):
        B, C, W, H = x.shape
        if self.training:
            if(B>1):
                m = x.mean(dim=0)
                v = x.var(dim=0) + self.eps  # torch.mean((x - m) ** 2, axis=0) + self.eps
            else:
                m = torch.zeros(C, W, H)
                v = torch.zeros(C, W, H) + self.eps
            self.batch_mean = None
        else:
            if self.batch_mean is None:
                self.set_batch_stats_func(x)
            m = self.batch_mean
            v = self.batch_var

        B, C, W, H = x.shape
        gamma = self.gamma.unsqueeze(2).unsqueeze(3)
        gamma = torch.repeat_interleave(gamma, H, dim=2)
        gamma = torch.repeat_interleave(gamma, W, dim=3)
        beta = self.beta.unsqueeze(2).unsqueeze(3)
        beta = torch.repeat_interleave(beta, H, dim=2)
        beta = torch.repeat_interleave(beta, W, dim=3)
        x_hat = (x - beta) * torch.exp(-gamma) * torch.sqrt(v) + m
        #print('reverse: dist:', torch.dist(x, x_hat))
        #print('reverse: x:', x[0,0,:3,:3])
        #print('reverse: x_hat:', x_hat[0,0,:3,:3])
        log_det = torch.sum(-gamma + 0.5 * torch.log(v))
        return x_hat, log_det

    def set_batch_stats_func(self, x):
        print("setting batch stats for validation")
        self.batch_mean = x.mean(dim=0)
        self.batch_var = x.var(dim=0) + self.eps

Testing with a single-batch tensor:

x = torch.rand(1,10,100,100)
Batch = BatchNorm(10)
x1,_ = Batch(x, False)
x2,_ = Batch(x1, True)
torch.dist(x,x2)

The output is approximately zero, which means both the forward and reverse paths work correctly. But for a multi-batch tensor:

x = torch.rand(3,10,100,100)
Batch = BatchNorm(10)
x1,_ = Batch(x, False)
x2,_ = Batch(x1, True)
torch.dist(x,x2)

In this case the result (the distance between the input and the reconstructed input) is a huge number, whereas it should be close to zero.
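
A minimal sketch, assuming gamma and beta are still at their zero initialization (so x_hat is just the plain normalization), that makes the mismatch visible: in training mode, reverse() recomputes the batch mean and variance from its own input, i.e. from the already-normalized x_hat, and those statistics are roughly 0 and 1 rather than the m and v that forward() divided by. With B == 1 both paths fall back to m = 0 and v = eps, which is why the single-batch case inverts exactly.

import torch

torch.manual_seed(0)
x = torch.rand(3, 10, 100, 100)

# statistics forward() computes from x
m, v = x.mean(dim=0), x.var(dim=0) + 1e-5
x_hat = (x - m) / torch.sqrt(v)

# statistics reverse() recomputes from its own input, i.e. from x_hat
m2, v2 = x_hat.mean(dim=0), x_hat.var(dim=0) + 1e-5

print(torch.dist(m, m2))  # large: x_hat is roughly zero-mean, x is not
print(torch.dist(v, v2))  # large: x_hat has roughly unit variance, x does not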

Tags: deep-learning, neural-network, pytorch, inverse, batch-normalization

Solution
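
A minimal sketch of one possible fix, assuming it is enough to cache the statistics that forward() actually used and reuse them in reverse(), instead of recomputing them from the already-normalized input. The class name InvertibleBatchNorm and the broadcasting shapes below are illustrative choices, not the only way to do it.

import torch
import torch.nn as nn

class InvertibleBatchNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # shaped (1, C, 1, 1) so gamma and beta broadcast over batch and spatial dims
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.batch_mean = None
        self.batch_var = None

    def forward(self, x, reverse=False):
        if reverse:
            return self.reverse(x)
        if self.training:
            m = x.mean(dim=0)
            v = x.var(dim=0) + self.eps
            # cache the statistics used here so reverse() can undo them exactly
            self.batch_mean, self.batch_var = m, v
        else:
            m, v = self.batch_mean, self.batch_var
        x_hat = (x - m) / torch.sqrt(v) * torch.exp(self.gamma) + self.beta
        log_det = torch.sum(self.gamma - 0.5 * torch.log(v))
        return x_hat, log_det

    def reverse(self, y):
        # reuse the cached statistics from the matching forward() call
        m, v = self.batch_mean, self.batch_var
        x = (y - self.beta) * torch.exp(-self.gamma) * torch.sqrt(v) + m
        log_det = torch.sum(-self.gamma + 0.5 * torch.log(v))
        return x, log_det

x = torch.rand(3, 10, 100, 100)
bn = InvertibleBatchNorm(10)
x1, _ = bn(x, False)
x2, _ = bn(x1, True)
print(torch.dist(x, x2))  # close to zero for B > 1 as well

The eval-mode branch assumes forward() has been called at least once in training mode so that batch_mean and batch_var are populated, mirroring the set_batch_stats_func idea in the question.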

