python - 模型权重没有更新,但损失正在减少
问题描述
以下代码是用大小为 64*64 的图像训练 MLP,同时使用损失 ||output - input||^2。
出于某种原因,我每个时期的权重没有更新,如最后所示。
class MLP(nn.Module):
def __init__(self, size_list):
super(MLP, self).__init__()
layers = []
self.size_list = size_list
for i in range(len(size_list) - 2):
layers.append(nn.Linear(size_list[i],size_list[i+1]))
layers.append(nn.ReLU())
layers.append(nn.Linear(size_list[-2], size_list[-1]))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
model_1 = MLP([4096, 64, 4096])
对于每个时期的训练:
def train_epoch(model, train_loader, criterion, optimizer):
model.train()
model.to(device)
running_loss = 0.0
start_time = time.time()
# train batch
for batch_idx, (data) in enumerate(train_loader):
optimizer.zero_grad()
data = data.to(device)
outputs = model(data)
loss = criterion(outputs, data)
running_loss += loss.item()
loss.backward()
optimizer.step()
end_time = time.time()
weight_ll = model.net[0].weight
running_loss /= len(train_loader)
print('Training Loss: ', running_loss, 'Time: ',end_time - start_time, 's')
return running_loss, outputs, weight_ll
用于训练数据:
n_epochs = 20
Train_loss = []
weights=[]
criterion = nn.MSELoss()
optimizer = optim.SGD(model_1.parameters(), lr = 0.1)
for i in range(n_epochs):
train_loss, output, weights_ll = train_epoch(model_1, trainloader, criterion, optimizer)
Train_loss.append(train_loss)
weights.append(weights_ll)
print('='*20)
现在,当我打印每个时期的第一个全连接层的权重时,它们并没有被更新。
print(weights[0][0])
print(weights[19][0])
上面的输出是(显示 epoch 0 和 epoch 19 的权重):
tensor([ 0.0086, 0.0069, -0.0048, ..., -0.0082, -0.0115, -0.0133],
grad_fn=<SelectBackward>)
tensor([ 0.0086, 0.0069, -0.0048, ..., -0.0082, -0.0115, -0.0133],
grad_fn=<SelectBackward>)
可能出了什么问题?看看我的损失,它以稳定的速度减少,但权重没有变化。
解决方案
尝试更改它weight_ll = model.net[0].weight.clone().detach()
或仅weight_ll = model.net[0].weight.clone()
在您的train_epoch()
功能中进行更改。你会看到权重不同。
说明:weights_ll
如果您不克隆它,则始终是最后一个 epoch 值。它将被视为图中相同的张量。这就是为什么你的weights[0][0]
equals weights[19][0]
,它们实际上是同一个张量。
推荐阅读
- c# - 文件名、目录名或卷标语法不正确,File.Copy
- html - 使用超链接在复选框 hack 中切换复选框的状态
- qt - 为什么 QML 关键事件不能被拦截?
- c++ - MKL 矩形矩阵就地转置:不使用多核?
- php - 如何使用 laravel 5 通过 jQuery ajax 删除记录?
- gulp - gulp-scss-lint 配置文件未验证 scss 文件并显示警告
- java - 将 while 循环中的 AudioRecord 缓冲区写入数组
- php - 如何从上次交易中获得价值,但不是今天的交易
- android - 如何在 Fragment Android 中编写 SQLite 数据库
- android - Androidx Cardview 上未显示阴影