python - Pytorch required_grad=False 在 GPU 上运行时不冻结网络参数
问题描述
在使用 Pytorch 进行训练时,我试图冻结一层玩具模型。在下面的代码中,当我在 CPU 上运行代码时,图层没有更新。(请参阅代码行print("%.8f" % np.max(np.abs(before -after)))
)。但是,当我在 GPU 上运行代码时,图层会更新。我的实施有什么问题?
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
def toNP(x):
return x.detach().to('cpu').numpy()
# toy feed-forward net
class Sub_Net(nn.Module):
def __init__(self):
super(Sub_Net, self).__init__()
self.fc1 = nn.Linear(10, 3)
self.fc2 = nn.Linear(3, 3)
self.fc3 = nn.Linear(3, 3)
self.fc4 = nn.Linear(3, 3)
self.fc5 = nn.Linear(3, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
x = self.fc5(x)
return x
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.Sub_Net = Sub_Net()
self.fc1 = nn.Linear(10, 3)
self.fc2 = nn.Linear(3, 3)
self.fc3 = nn.Linear(3, 3)
self.fc4 = nn.Linear(3, 3)
self.fc5 = nn.Linear(3, 1)
def forward(self, x):
y = self.Sub_Net(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
x = self.fc5(x)
return x+y
def generator_step(net, optimizer, criterion, input, target):
output = net(input)
loss = criterion(output, target)
net.zero_grad()
loss.backward()
optimizer.step()
def discrimination_step(net, optimizer, criterion, input, target):
for param in net.Sub_Net.parameters():
param.requires_grad = False
before = toNP(net.Sub_Net.fc2.weight)
output = net(input)
loss = criterion(output, target)
net.zero_grad()
loss.backward()
optimizer.step()
after = toNP(net.Sub_Net.fc2.weight)
print("%.8f" % np.max(np.abs(before -after)) )
# Run model on GPU
# net = Net().type(torch.cuda.FloatTensor)
# random_input = Variable(torch.randn(10, )).cuda()
# random_target = Variable(torch.randn(1, )).cuda()
# Run model on CPU
net = Net()
random_input = Variable(torch.randn(10, ))
random_target = Variable(torch.randn(1, ))
# loss
criterion = nn.MSELoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.1)
for epoch in range(1, 10):
generator_step(net, optimizer, criterion, random_input, random_target)
discrimination_step(net, optimizer, criterion, random_input, random_target)
在 CPU 上运行时的结果
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
在 GPU 上运行时的结果
0.06700575
0.04242781
0.03090768
0.02379489
0.01885229
0.01519108
0.01237211
0.01014686
0.00836059
解决方案
似乎不是 CPU / GPU 的行为不同,而是与 toNP 函数中的 .to('cpu') 函数有关。
如果给定的张量在 GPU 上,它会在 cpu 上返回复制的张量,而当给定的张量已经在 CPU 上时,它会返回给定的原始对象。请参阅本网站的更多内容。
为了澄清,我在您的discriminator_step中添加了打印功能,如下所示:
def discrimination_step(net, optimizer, criterion, input, target):
for param in net.Sub_Net.parameters():
param.requires_grad = False
before = toNP(net.Sub_Net.fc2.weight)
print(f'Before:\n{before}')
output = net(input)
loss = criterion(output, target)
net.zero_grad()
loss.backward()
optimizer.step()
print(f'Before:\n{before}')
after = toNP(net.Sub_Net.fc2.weight)
print(f'After:\n{after}')
print("dff: %.8f" % np.max(np.abs(before -after)) )
然后代码在 CPU 上产生结果(1 epoch):
Before:
[[-0.0222426 0.06449176 0.41833472]
[-0.3276776 -0.22486973 0.38021228]
[-0.37726757 0.26268137 -0.05000275]]
Before:
[[ 0.04476321 0.13149747 0.48534054]
[-0.2606718 -0.15786391 0.44721812]
[-0.31026173 0.32968715 0.01700307]]
After:
[[ 0.04476321 0.13149747 0.48534054]
[-0.2606718 -0.15786391 0.44721812]
[-0.31026173 0.32968715 0.01700307]]
dff: 0.00000000
在 GPU 上:
Before:
[[-0.06808002 0.39740798 0.55723506]
[-0.17421165 -0.36702433 -0.4208245 ]
[-0.37865937 -0.52346057 -0.15856335]]
Before:
[[-0.06808002 0.39740798 0.55723506]
[-0.17421165 -0.36702433 -0.4208245 ]
[-0.37865937 -0.52346057 -0.15856335]]
After:
[[-0.13508584 0.4644138 0.6242409 ]
[-0.24121748 -0.30001852 -0.35381868]
[-0.31165355 -0.5904664 -0.22556916]]
dff: 0.06700583
它显示在值更改之前,因为在 CPU 上返回的之前张量与net.Sub_Net.fc2.weight共享相同的存储。无论 CPU 和 GPU 如何,层都会更新,因为它们已经在 Adam 优化器的参数组中。
推荐阅读
- react-native - Lottie 动画在 android 中隐藏图层。(反应原生)
- jquery - 隔离 jquery 切换多个按钮
- r - 如何运行滚动窗口面板回归?
- r - 从文件夹中读取文件,尽管索引正确,但我收到“文件错误(文件,“rt”):无法打开连接”错误
- ios - 如何在 Swift 中显示日文字符?
- c# - 如何检查当前可执行文件是否已经绑定了 Windows 服务
- javascript - 计算列表项数量的问题
- html - 带有图像背景 CSS 的 Html
- r - (快速)对元素具有“a/b”格式的矩阵列进行成对比较
- javascript - 单击复选框时添加和删除字符串中的值