python - PyTorch 二进制分类 - 相同的网络结构,“更简单”的数据,但性能更差?
问题描述
为了掌握 PyTorch(以及一般的深度学习),我首先研究了一些基本的分类示例。一个这样的例子是对使用 sklearn 创建的非线性数据集进行分类(完整代码可在此处作为笔记本获取)
n_pts = 500
X, y = datasets.make_circles(n_samples=n_pts, random_state=123, noise=0.1, factor=0.2)
x_data = torch.FloatTensor(X)
y_data = torch.FloatTensor(y.reshape(500, 1))
然后使用非常基本的神经网络对其进行准确分类
class Model(nn.Module):
def __init__(self, input_size, H1, output_size):
super().__init__()
self.linear = nn.Linear(input_size, H1)
self.linear2 = nn.Linear(H1, output_size)
def forward(self, x):
x = torch.sigmoid(self.linear(x))
x = torch.sigmoid(self.linear2(x))
return x
def predict(self, x):
pred = self.forward(x)
if pred >= 0.5:
return 1
else:
return 0
由于我对健康数据感兴趣,因此我决定尝试使用相同的网络结构对一些基本的现实世界数据集进行分类。我从这里获取了一名患者的心率数据,并对其进行了更改,因此所有 > 91 的值都将被标记为异常(例如 a1
和所有 <= 91 的值都标记为 a 0
)。这完全是任意的,但我只是想看看分类是如何工作的。此示例的完整笔记本在这里。
对我来说不直观的是为什么第一个示例在 1,000 个 epochs 后达到 0.0016 的损失,而第二个示例在 10,000 个 epochs 后仅达到 0.4296 的损失
也许我天真地认为心率示例会更容易分类。任何有助于我理解为什么这不是我所看到的见解都会很棒!
解决方案
TL;博士
您的输入数据未标准化。
- 利用
x_data = (x_data - x_data.mean()) / x_data.std()
- 提高学习率
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
仅在 1000 次迭代中收敛。
更多细节
您拥有的两个示例之间的主要区别在于,x
第一个示例中的数据以 (0, 0) 为中心,并且方差非常低。
另一方面,第二个例子中的数据以 92 为中心,方差比较大。
当您随机初始化权重时,不会考虑数据中的这种初始偏差,这是基于输入大致正态分布在零附近的假设而完成的。
优化过程几乎不可能补偿这种严重偏差 - 因此模型陷入次优解决方案。
一旦对输入进行归一化,通过减去均值并除以标准差,优化过程再次变得稳定并迅速收敛到一个好的解决方案。
有关输入归一化和权重初始化的更多详细信息,您可以阅读He et al Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification (ICCV 2015) 中的第 2.2 节。
如果我无法规范化数据怎么办?
如果由于某种原因,您无法提前计算均值和标准数据,您仍然可以nn.BatchNorm1d
在训练过程中使用它来估计和规范化数据。例如
class Model(nn.Module):
def __init__(self, input_size, H1, output_size):
super().__init__()
self.bn = nn.BatchNorm1d(input_size) # adding batchnorm
self.linear = nn.Linear(input_size, H1)
self.linear2 = nn.Linear(H1, output_size)
def forward(self, x):
x = torch.sigmoid(self.linear(self.bn(x))) # batchnorm the input x
x = torch.sigmoid(self.linear2(x))
return x
这种修改没有对输入数据进行任何更改,仅在 1000 个 epoch 之后产生类似的收敛:
一个小评论
为了数值稳定性,最好使用nn.BCEWithLogitsLoss
代替nn.BCELoss
. 为此,您需要torch.sigmoid
从forward()
输出中删除 ,sigmoid
将在损失内计算。
例如,参见this thread关于二进制预测的相关 sigmoid + 交叉熵损失。
推荐阅读
- python - app.teardown_appcontext() 在 Flask 中向应用实例注册函数时的调用顺序
- http - HTTP 服务器发送不完整到 http clinet recv
- spring - 在 Kubernetes 中运行的 Spring Batch
- python - VSCode 调试器 FileNotFoundError
- python - subprocess.check_output 追加 \n 即使字符串不包含它
- json - 解析json并放入数组
- python - 如何在 plotly Dash 中使用两个 DatePickerRange 创建线图?
- python - 如何通过包含列描述的 Python 客户端创建 BigQuery 视图
- node.js - nodejs - 对象数组获取信息并为每个对象添加新对象
- ruby-on-rails - 参数数量错误(给定 2,预期 0..1)