pytorch - 使用 BatchNorm 的 Pytorch Gradient wrt 输入
问题描述
我正在尝试计算一个简单神经网络的输出相对于输入的梯度。当我不使用 BatchNorm 图层时,结果看起来不错。一旦我使用它,结果似乎没有多大意义。下面是一个重现效果的简短示例。
class Net(nn.Module):
def __init__(self, batch_norm):
super().__init__()
self.batch_norm = batch_norm
self.act_fn = nn.Tanh()
self.aff1 = nn.Linear(1, 10)
self.aff2 = nn.Linear(10, 1)
if batch_norm:
self.bn = nn.BatchNorm1d(10, affine=False) # False for simplicity
def forward(self, x):
x = self.aff1(x)
x = self.act_fn(x)
if self.batch_norm:
x = self.bn(x)
x = self.aff2(x)
return x
x_vals = torch.linspace(0, 1, 100)
x_vals.requires_grad = True
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for seed, bn, ax1 in zip([11, 4], [False, True], axs): # different seeds for better illustration of effect
torch.manual_seed(seed)
net = Net(batch_norm=bn)
net.train()
pred = net(x_vals[:, None])
pred_dx = torch.autograd.grad(pred.sum(), x_vals, create_graph=True)[0]
# visualization
ax2 = ax1.twinx()
ax1.plot(x_vals.detach(), pred.detach())
ax2.plot(x_vals.detach(), pred_dx.detach(), linestyle='--', color='orange')
min_idx = torch.argmin((pred[1:]-pred[:-1])**2)
ax2.axvline(x_vals[min_idx].detach(), color='gray', linestyle='dotted')
ax2.axhline(0, color='gray', linestyle='dotted')
ax1.set_title(('With' if bn else 'Without') + ' Batch Norm')
plt.show()
当我使用评估模式时,结果似乎也很好。不幸的是,我不能只切换到 eval() 模式,因为我的问题(PINN)的性质需要在训练期间计算梯度。
我了解在训练期间会更新运行均值和方差。也许这有影响?我还能以某种方式获得正确的渐变吗?
谢谢你的帮助!
解决方案
由于使用了“.train()”这个词,这有点令人困惑,但是:
- net.train() 将批量标准化和 dropout 等层置于活动状态。所以是的,运行均值和方差将被更新,但梯度仍将被计算。
- 使用 torch.no_grad() 或将变量的 require_grad 设置为 False 将阻止任何梯度计算
因此,如果您需要“训练”批量归一化,您将无法在不受批量归一化影响的情况下获得梯度。如果您不需要训练它们,只需将模型置于评估模式,因为梯度仍在计算中
推荐阅读
- python - 使用烧瓶传递变量python Web应用程序时出错
- c# - 内存池之间的区别
和数组池 - c# - 将组合键与 Entity Framework Core 一起使用并将其中的一部分用作外键
- flutter - 从颤动中运行 cURL 推荐
- python - 神经网络决策边界垂直于真实边界
- java - 配置覆盖默认应用程序上下文 bean 的 Spring 的 @DataJpaTest
- javascript - 在儿童内部设置状态与作为道具发送有什么区别?
- python - 从python中的列表中遍历树进行计算?
- python - 在 python 中写入串行端口以供 arduino 使用的问题(打开灯/移动伺服等...)
- python - 为循环 PyGame 绘制双倍加速