python - 定义损失函数,以便使用外部数组
问题描述
在我的神经网络 (RNN) 中,我定义了损失函数,以便神经网络的输出用于查找索引(二进制),然后索引用于从数组中提取所需的元素,这反过来将是用于计算 MSELoss。
但是,该程序给出了parameter().grad = None
错误,这主要是因为图表在某处中断。定义的错误函数有什么问题。
框架:Pytorch
代码如下: 神经网络:
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.hidden_size = 8
# self.input_size = 2
self.h2o = nn.Linear(self.hidden_size, 1)
self.h2h = nn.Linear(self.hidden_size, self.hidden_size)
self.sigmoid = nn.Sigmoid()
def forward(self,hidden):
output = self.h2o(hidden)
output = self.sigmoid(output)
hidden = self.h2h(hidden)
return output, hidden
def init_hidden(self):
return torch.zeros(1, self.hidden_size)
损失函数、训练步骤和训练
rnn = RNN()
criterion = nn.MSELoss()
def loss_function(previous, output, index):
code = 2*(output > 0.5).long()
current = Q_m2[code:code+2, i]
return criterion(current, previous), current
def train_step():
hidden = rnn.init_hidden()
rnn.zero_grad()
# Q_m2.requires_grad = True
# Q_m2.create_graph = True
loss = 0
previous = Q_m[0:2, 0]
for i in range(1, samples):
output, hidden = rnn(hidden)
l, previous = loss_function(previous, output, i)
loss+=l
loss.backward()
# Q_m2.retain_grad()
for p in rnn.parameters():
p.data.add_(p.grad.data, alpha=-0.05)
return output, loss.item()/(samples - 1)
def training(epochs):
running_loss = 0
for i in range(epochs):
output, loss = train_step()
print(f'Epoch Number: {i+1}, Loss: {loss}')
running_loss +=loss
Q_m2
Q_m = np.zeros((4, samples))
for i in range(samples):
Q_m[:,i] = q_x(U_m[:,i])
Q_m = torch.FloatTensor(Q_m)
Q_m2 = Q_m
Q_m2.requires_grad = True
Q_m2.create_graph = True
错误:
<ipython-input-36-feefd257c97a> in train_step()
21 # Q_m2.retain_grad()
22 for p in rnn.parameters():
---> 23 p.data.add_(p.grad.data, alpha=-0.05)
24 return output, loss.item()/(samples - 1)
25
AttributeError: 'NoneType' object has no attribute 'data'
解决方案
这是K. Frank在讨论.pytorch.org向我建议的一个可能的解决方案
当我读到它时,代码被计算为 0 或 2。您可以将输出(根据需要进行适当处理)解释为代码应该为 0 与 2 的概率,然后使用该概率形成加权平均值Q_m2 数组中的 0 和 2 条目。
推荐阅读
- php - 如何使用 HTTPRequest 使用 Flutter 登录 Laravel
- wpf - 可重用用户控件中的 WPF 验证错误样式?
- mongodb - 无法在 alpine linux 中安装 php7-mongodb
- java - JSP 页面中未显示的变量
- php - 选项值与正确值不匹配
- php - 如何通过 __set() 方法使用 PDO::FETCH_CLASS?
- android - 无法使用 com.tns 查找课程
- sql - 将姐妹表值应用于主表中的所有父/子关系
- automation - 如何使用 botium 绑定更改 Botium 测试的测试套件名称。当前它是默认名称
- linux - 不使用pop操作读取数据有优势吗?