python - 无法使用 .backwards() 在 Pytorch 模块中反向传播;权重没有更新
问题描述
我愿意创建一个网络来估计信号的下一个样本。我用一个简单的罪信号来说明。但是当我运行代码时,我得到了噪音作为输出。然后检查图层权重并发现它们没有更新。我在这里找不到错误。
class Model(nn.Module):
def __init__(self,in_dim,hidden_dim,num_classes):
super(Model, self).__init__()
self.layer1 = nn.Linear(in_dim,hidden_dim)
self.layer2 = nn.Linear(hidden_dim,hidden_dim)
self.layer3 = nn.Linear(hidden_dim,num_classes)
self.relu = nn.ReLU()
def forward(self,x):
a = self.relu(self.layer1(x))
a = self.relu(self.layer2(a))
return self.relu(self.layer3(a))
火车:
def train(epoch,L,depth):
criteria = nn.MSELoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
t = np.linspace(0,2,L+2)
fs = L+2
trn_loss = list()
for f in range(0,epoch):
phase = f/np.pi
x = np.sin(2*np.pi*t*fs+phase)
x = torch.from_numpy(x).detach().float()
optimizer.zero_grad()
x_hat = model(x[:-2])
currentCost = criteria(x_hat,x[-2])
trn_loss.append(currentCost.item())
print(model.layer1.weight.data.clone())
currentCost.backward()
optimizer.step()
print(model.layer1.weight.data.clone())
sys.exit('DEBUG')
输出:
tensor([[-0.1715, -0.1696, 0.0424, ..., 0.0154, 0.1450, -0.0544],
[ 0.0368, 0.1427, -0.1419, ..., 0.0966, 0.0298, -0.0659],
[-0.1641, -0.1551, 0.0570, ..., -0.0227, -0.1426, -0.0648],
...,
[-0.0684, -0.1707, -0.0711, ..., 0.0788, 0.1386, 0.1546],
[ 0.1401, -0.0922, -0.0104, ..., -0.0490, 0.0404, 0.1038],
[-0.0604, -0.0517, 0.0715, ..., -0.1200, 0.0014, 0.0215]])
tensor([[-0.1715, -0.1696, 0.0424, ..., 0.0154, 0.1450, -0.0544],
[ 0.0368, 0.1427, -0.1419, ..., 0.0966, 0.0298, -0.0659],
[-0.1641, -0.1551, 0.0570, ..., -0.0227, -0.1426, -0.0648],
...,
[-0.0684, -0.1707, -0.0711, ..., 0.0788, 0.1386, 0.1546],
[ 0.1401, -0.0922, -0.0104, ..., -0.0490, 0.0404, 0.1038],
[-0.0604, -0.0517, 0.0715, ..., -0.1200, 0.0014, 0.0215]])
解决方案
forward
您在通话中的最后一层使用ReLU
激活。这将网络的输出限制在[0, +inf)
范围内。
请注意您的目标在[-1, 1]
范围内,因此网络无法输出一半(负)值(对于正部分,它必须将+inf
可能的值压缩到[0, 1]
空间中)。
您应该更改return self.relu(self.layer3(a))
为return self.layer3(a)
in forward
。
更好的是,为了帮助您的网络适应[-1, 1]
范围,请使用torch.tanh
激活,这样return torch.tanh(self.layer3(a))
应该效果最好。
推荐阅读
- reactjs - React 应用程序无法编译,似乎 typescript 没有被转译为 JS?
- javascript - Vue 测试 - '... .push 不是函数'
- scala - 根据RDD上的Userid计算用户评分平均值和映射
- html - 引导轮播卡从下到上定位
- node.js - res.sendFile, res.download PDF 文件
- typescript - 仅当道具名称是类字段时,如何将道具的值添加到类的实例
- python - Django TinyMCE 添加自定义样式格式
- c++ - 如何在将我的类分成 .h 和 .cpp 文件时使用聚合?
- ssl - MQTT 和 SSL/TLS
- jmeter-5.0 - 无法从 json 中提取值并在后续请求中使用它