python - 哪种损失函数合适?
问题描述
我正在使用 PyTorch,但对图书馆还是很陌生。
我的输入和输出之间的关系由 给出y = ax + b
,其中a
和b
是从某个分布(例如均匀)中采样的,也就是说,它们是随机的。我想训练一个网络x
在看到y
和时进行预测a
。我正在使用一个名为 的网络probability_network
,它带有nn.Linear
层。有N
(比如 10 个)类可供选择x
。
class ProabilityNetwork(nn.Module):
def __init__(self):
super(ProabilityNetwork, self).__init__()
self.fc1 = nn.Linear(8, 76)
self.fc2 = nn.Linear(76, 150)
self.fc3 = nn.Linear(150, 75)
self.fc4 = nn.Linear(75, 14)
self.fc5 = nn.Linear(14, 10)
self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=1)
def forward(self, inputs):
return self.softmax(self.fc5(self.fc4(self.relu(self.fc3(self.relu(self.fc2(self.relu(self.fc1(inputs)))))))))
probabilty_network = ProbabilityNetwork()
一看y
,损失函数应该帮助网络预测一个x
最小化的||y-ax||^2
。中的所有数量y = ax + b
都是向量(在此示例中,每个长度为 4)。我已经尝试过以下损失函数。
prob_values = probabilty_network(torch.cat([y, a], dim=0)) # shape: (batch_size, 10)
x_hat = mapping_tensor[torch.argmax(prob_values, dim=1)] # Mapping from probability to one of 10 classes, mapping_tensor is an array of shape (10, 4)
mse_loss = nn.MSELoss()
loss = mse_loss(y, a*x_hat)
例如,mapping_tensor
可以包含从0
( 0000
) 到9
( 1001
) 的值的二进制表示。我需要类的二进制表示的原因是我需要一个向量x
来表示 loss ||y-ax||^2
。在这种情况下,x
是一个4
长度向量,而神经网络的输出是一个10
长度向量。
上面的设置不起作用。预测类中的一半值(以二进制形式写出)总是错误的,这意味着网络在训练时很混乱。
此外,这不是一个无法实现的问题。存在上述损失函数的解决方案(当然有误差,但误差远小于 50%),但计算量很大。我正在尝试检查网络是否可以以某种方式学会以较低的复杂性进行预测。任何帮助表示赞赏。谢谢。
此外,从优化的角度来看,损失函数是最好的解决方案(据我所知)。所以,改变损失函数只会导致更糟糕的结果。
另一种看待问题的方法如下。假设网络看到y
和a
。然后网络计算|y-ax|
每个类x
(从 10 个可能的类中),然后选择具有最小计算值的类。我的问题是,我可以使用什么损失函数来以这种方式训练网络?
解决方案
我根据这个陈述回答:
我想训练一个网络来
x
预测y
。
其中y = ax + b
、a
和b
随机向量(乘性和加性噪声)。
您可以以有监督的方式训练您的模型。鉴于y
,您的模型预测x_pred
。然后将损失函数定义为您的预测x_pred
与基本事实之间的欧几里得距离x
:
loss = torch.nn.functional.mse_loss(x_pred, x)
推荐阅读
- php - 如何使用 php decode 显示来自 json url 的数据?
- python-3.x - 堆参数必须是 python 3 中的列表
- java - 使用原始 python 包而不是 jython 包
- python - 具有两列的熊猫数据框中的自定义数学函数
- loops - Julia for 循环中的迭代索引
- python - 使用 matplotlib scatter 绘制负值
- excel - If 数组中的索引
- javascript - jQuery validate 插件中的文件上传问题
- python - 自定义 django-allauth password_reset_from_key 模板
- amazon-web-services - 将 S3 存储桶安装为 EC2 实例中的驱动器是复制粘贴还是直接将文件保存在存储桶中?