PyTorch regression produces the same number for every prediction

Problem description

I applied a simple NN regression to the classic Boston housing dataset. The problem I am running into is that when I use the trained model to make predictions, it always predicts the same number. Here is my code:

import numpy as np
import pandas as pd
from sklearn import datasets
data = datasets.load_boston()

X = pd.DataFrame(data.data, columns=data.feature_names)
Y = pd.DataFrame(data.target, columns=["MEDV"])

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.20, random_state=1234)

import torch

x_train = torch.tensor(X_train.values, dtype=torch.float)
y_train = torch.tensor(y_train.values, dtype=torch.float)
x_test = torch.tensor(X_test.values, dtype=torch.float)
y_test = torch.tensor(y_test.values, dtype=torch.float)

import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

net = Net(n_feature=13, n_hidden=50, n_output=1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss() 


for t in range(200):
    prediction = net(x_train)

    loss = loss_func(prediction, y_train)
    print(t, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Here is the loss logged for each epoch:

0 6414.029296875
1 3.883837028532696e+19
2 5.963952279684907e+18
3 2.1470207536047063e+18
4 7.729279495852524e+17
5 2.7825424945486234e+17
6 1.001715136546734e+17
7 3.606170935335322e+16
8 1.2982211888283648e+16
9 4673591211720704.0
10 1682493577101312.0
11 605696890503168.0
12 218050926215168.0
13 78498328739840.0
14 28259370663936.0
15 10173364043776.0
16 3662410940416.0
17 1318467796992.0
18 474648543232.0
19 170873470976.0
20 61514457088.0
21 22145208320.0
22 7972275200.0
23 2870019072.0
24 1033206912.0
25 371954528.0
26 133903616.0
27 48205360.0
28 17353982.0
29 6247483.5
30 2249145.25
31 809743.25
32 291558.65625
33 105012.1171875
34 37855.3984375
35 13678.982421875
36 4975.46826171875
37 1842.20361328125
38 714.2283325195312
39 308.15716552734375
40 161.97158813476562
41 109.34477233886719
42 90.39911651611328
43 83.57867431640625
44 81.12332153320312
45 80.23938751220703
46 79.92117309570312
47 79.8066177368164
48 79.76537322998047
49 79.75053405761719
50 79.74518585205078
51 79.7432632446289
52 79.74256896972656
53 79.74231719970703
54 79.74223327636719
55 79.74219512939453
56 79.7421875
57 79.74217987060547
58 79.74217987060547
59 79.74217987060547
60 79.74217987060547
61 79.74217987060547
62 79.74217987060547
63 79.74217987060547
64 79.74217987060547
65 79.74217987060547
66 79.74217987060547
67 79.74217987060547
68 79.74217987060547
69 79.74217987060547
70 79.74217987060547
71 79.74217987060547
72 79.74217987060547
73 79.74217987060547
74 79.74217987060547
75 79.74217987060547
76 79.74217987060547
77 79.74217987060547
78 79.74217987060547
79 79.74217987060547
80 79.74217987060547
81 79.74217987060547
82 79.74217987060547
83 79.74217987060547
84 79.74217987060547
85 79.74217987060547
86 79.74217987060547
87 79.74217987060547
88 79.74217987060547
89 79.74217987060547
90 79.74217987060547
91 79.74217987060547
92 79.74217987060547
93 79.74217987060547
94 79.74217987060547
95 79.74217987060547
96 79.74217987060547
97 79.74217987060547
98 79.74217987060547
99 79.74217987060547
100 79.74217987060547
101 79.74217987060547
102 79.74217987060547
103 79.74217987060547
104 79.74217987060547
105 79.74217987060547
106 79.74217987060547
107 79.74217987060547
108 79.74217987060547
109 79.74217987060547
110 79.74217987060547
111 79.74217987060547
112 79.74217987060547
113 79.74217987060547
114 79.74217987060547
115 79.74217987060547
116 79.74217987060547
117 79.74217987060547
118 79.74217987060547
119 79.74217987060547
120 79.74217987060547
121 79.74217987060547
122 79.74217987060547
123 79.74217987060547
124 79.74217987060547
125 79.74217987060547
126 79.74217987060547
127 79.74217987060547
128 79.74217987060547
129 79.74217987060547
130 79.74217987060547
131 79.74217987060547
132 79.74217987060547
133 79.74217987060547
134 79.74217987060547
135 79.74217987060547
136 79.74217987060547
137 79.74217987060547
138 79.74217987060547
139 79.74217987060547
140 79.74217987060547
141 79.74217987060547
142 79.74217987060547
143 79.74217987060547
144 79.74217987060547
145 79.74217987060547
146 79.74217987060547
147 79.74217987060547
148 79.74217987060547
149 79.74217987060547
150 79.74217987060547
151 79.74217987060547
152 79.74217987060547
153 79.74217987060547
154 79.74217987060547
155 79.74217987060547
156 79.74217987060547
157 79.74217987060547
158 79.74217987060547
159 79.74217987060547
160 79.74217987060547
161 79.74217987060547
162 79.74217987060547
163 79.74217987060547
164 79.74217987060547
165 79.74217987060547
166 79.74217987060547
167 79.74217987060547
168 79.74217987060547
169 79.74217987060547
170 79.74217987060547
171 79.74217987060547
172 79.74217987060547
173 79.74217987060547
174 79.74217987060547
175 79.74217987060547
176 79.74217987060547
177 79.74217987060547
178 79.74217987060547
179 79.74217987060547
180 79.74217987060547
181 79.74217987060547
182 79.74217987060547
183 79.74217987060547
184 79.74217987060547
185 79.74217987060547
186 79.74217987060547
187 79.74217987060547
188 79.74217987060547
189 79.74217987060547
190 79.74217987060547
191 79.74217987060547
192 79.74217987060547
193 79.74217987060547
194 79.74217987060547
195 79.74217987060547
196 79.74217987060547
197 79.74217987060547
198 79.74217987060547
199 79.74217987060547

After training the model, I make predictions with the following code:

with torch.no_grad():
    y_val = net(x_test)

When I print the predictions, I get the following result:

tensor([[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099],
[22.4099]])

Tags: python, machine-learning, deep-learning, neural-network, pytorch

Solution


It looks to me like the model is not fitting your dataset well. After going through your code, I think I can guess what went wrong.

The way you are running gradient descent looks incorrect to me. Keep in mind that we are optimizing a non-convex function, so packing the entire training set into a single batch and taking full-batch steps does not work well: your model will get stuck at a local minimum, which is usually not good enough. This could turn into a much longer discussion requiring a lengthy explanation; I found a good reference for you Here on why this is problematic for non-convex functions.

My suggestion is to sample your training data into mini-batches and run the forward-backward update on one mini-batch at a time inside your epoch loop. Try batch sizes from, say, 8 to 96 and see how it works. For your toy example, you can simply shuffle the training data at every epoch and pick the mini-batches one by one. If you want to be more advanced (or rather, more standard for deep learning in PyTorch), you can write a PyTorch Dataset to handle the data loading and batching, as sketched below. If you do this correctly, the loss on the training data should become quite small.
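To illustrate the idea, here is a minimal sketch of such a mini-batch loop using torch.utils.data.TensorDataset and DataLoader. It reuses the Net class and the x_train / y_train tensors defined above; the batch size of 32 and the lower learning rate of 0.01 are assumptions for this example, not values from the question.

from torch.utils.data import TensorDataset, DataLoader

# Wrap the existing tensors so DataLoader can shuffle and batch them.
train_ds = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)  # batch size is an assumption

net = Net(n_feature=13, n_hidden=50, n_output=1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  # smaller lr than 0.2, assumed for stability
loss_func = torch.nn.MSELoss()

for epoch in range(200):
    running_loss = 0.0
    for xb, yb in train_loader:            # one forward-backward update per mini-batch
        prediction = net(xb)
        loss = loss_func(prediction, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * xb.size(0)
    print(epoch, running_loss / len(train_ds))  # average training loss for this epoch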

Edit:

You may also want to decay the learning rate gradually, for example reducing it by a factor of 10 after every 60 epochs, to optimize the loss function better; see the sketch below.
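A rough sketch of such a schedule with torch.optim.lr_scheduler.StepLR is shown here; it reuses net, loss_func, and train_loader from the mini-batch sketch above. The starting learning rate of 0.01 is an assumption, while step_size=60 and gamma=0.1 implement the "divide by 10 every 60 epochs" suggestion.

optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  # starting lr is an assumption
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)

for epoch in range(200):
    for xb, yb in train_loader:
        loss = loss_func(net(xb), yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()  # multiply the learning rate by 0.1 every 60 epochs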

