python-3.x - Bias gradient in linear regression stays very small compared with the weight gradient, and the intercept is not learned correctly
Problem description
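I put together a dummy model to demonstrate linear regression in PyTorch, but my model does not learn correctly. It does a good job of learning the slope, but the intercept barely moves. Printing the gradients at every epoch shows that the gradient of the bias is in fact much smaller than the gradient of the weight. Why is that, and how can I fix it so that the intercept is learned correctly?
This is what happens (with the slope a set to 0 to illustrate):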
import numpy as np
import matplotlib.pyplot as plt
import torch

# Create some dummy data: we establish a linear relationship between x and y
a = np.random.rand()
b = np.random.rand()
a = 0  # slope set to 0 to illustrate the problem
x = np.linspace(start=0, stop=100, num=100)
y = a * x + b
# Now let's create some noisy measurements
noise = np.random.normal(size=100)
y_noisy = a * x + b + noise
# What's the overall error?
mse_actual = np.sum(np.power(y - y_noisy, 2)) / len(y)
# Visualize
plt.scatter(x, y_noisy, label='Measurements', alpha=.7)
plt.plot(x, y, 'r', label='Underlying')
plt.legend()
plt.show()
# Let's learn something!
inputs = torch.from_numpy(x).type(torch.FloatTensor).unsqueeze(1)
targets = torch.from_numpy(y_noisy).type(torch.FloatTensor).unsqueeze(1)
# This is our model (a single weight plus a bias, no hidden layer)
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
loss_function = torch.nn.MSELoss()
# What does it predict right now?
with torch.no_grad():
    preds = model(inputs)
# Visualize
plt.scatter(x, y_noisy, color='blue', label='Measurements', alpha=.7)
plt.plot(x, preds.numpy().ravel(), color='orange', label='Predictions', alpha=.7)
plt.plot(x, y, 'r', label='Underlying')
plt.legend()
plt.show()
# Let's train!
epochs = 100
a_s, b_s = [], []
for epoch in range(epochs):
    # Reset the accumulated gradients
    optimizer.zero_grad()
    # Predict values using the current model
    preds = model(inputs)
    # How far off are we? (MSELoss expects (input, target))
    loss = loss_function(preds, targets)
    # Calculate the gradients
    loss.backward()
    # Update the model
    optimizer.step()
    for name, p in model.named_parameters():
        print('Grads:', name, p.grad)
    # New parameters
    a_s.append(model.weight.item())
    b_s.append(model.bias.item())
    print(f"Epoch {epoch+1} -- loss = {loss.item()}")
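A minimal check of the observation above: one backward pass from zero weight and zero bias, comparing the two gradients. The fixed intercept b = 0.7 and the seeded generator are assumptions made for this sketch so the result is reproducible; the original draws b at random.

```python
import numpy as np
import torch

# Assumption: fixed intercept b = 0.7 and a seeded RNG (the question uses a random b).
rng = np.random.default_rng(0)
a, b = 0.0, 0.7
x = np.linspace(start=0, stop=100, num=100)
y_noisy = a * x + b + rng.normal(size=100)

inputs = torch.from_numpy(x).float().unsqueeze(1)
targets = torch.from_numpy(y_noisy).float().unsqueeze(1)

# Start from w = 0, b = 0 so the residual is simply -y_noisy.
model = torch.nn.Linear(1, 1)
torch.nn.init.zeros_(model.weight)
torch.nn.init.zeros_(model.bias)

loss = torch.nn.MSELoss()(model(inputs), targets)
loss.backward()

weight_grad = model.weight.grad.item()
bias_grad = model.bias.grad.item()
ratio = abs(weight_grad) / abs(bias_grad)
print(f"weight grad: {weight_grad:.3f}  bias grad: {bias_grad:.3f}  ratio: {ratio:.1f}")
```

With x spread over [0, 100], the weight gradient comes out tens of times larger than the bias gradient, which matches what the printed gradients show.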
Solution
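This is a bit of a non-answer, but just use more epochs or add more data points. With only 100 data points and noise as large as yours (which becomes obvious if you simply plot the initial data), the model will struggle with MSE as the loss. In the code below I also scaled the noise down by a factor of 10.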
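The size gap between the two gradients follows directly from the MSE loss. A short derivation, writing the prediction as w x_i + b, where w and b are the weight and bias of `nn.Linear(1, 1)`:

$$
L(w, b) = \frac{1}{N}\sum_{i=1}^{N} \left(w x_i + b - y_i\right)^2
$$

$$
\frac{\partial L}{\partial w} = \frac{2}{N}\sum_{i=1}^{N} \left(w x_i + b - y_i\right) x_i,
\qquad
\frac{\partial L}{\partial b} = \frac{2}{N}\sum_{i=1}^{N} \left(w x_i + b - y_i\right)
$$

$$
\frac{\partial^2 L}{\partial w^2} = \frac{2}{N}\sum_{i=1}^{N} x_i^2 \approx 6700,
\qquad
\frac{\partial^2 L}{\partial b^2} = 2
\qquad \text{(for } x = \texttt{linspace(0, 100, 100)}\text{)}
$$

Each term of the weight gradient is the corresponding bias-gradient term multiplied by x_i, which ranges over [0, 100]. When the residuals are roughly constant (as with a = 0 and a misfit intercept), the weight gradient is therefore about the mean of x (≈ 50) times the bias gradient. The curvature along w is also thousands of times larger than along b, so a learning rate small enough to keep the w update stable (for plain gradient descent, roughly η < 2/λ_max ≈ 3 × 10⁻⁴ here) moves b only slightly per step. Standardizing x to zero mean and unit variance removes this imbalance.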
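I can't see your images (imgur is blocked at work...), but I found that if you don't adjust the axes of the matplotlib plot it looks bad, because with a=0 it is heavily zoomed in, so I zoomed out as well: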
import numpy as np
import matplotlib.pyplot as plt
import torch

# Create some dummy data: we establish a linear relationship between x and y
a = np.random.rand()
b = np.random.rand()
a = 0
N = 10000
x = np.linspace(start=0, stop=100, num=N)
y = a * x + b
# Now let's create some noisy measurements (noise scaled down by 10)
noise = np.random.normal(size=N) * 0.1
y_noisy = a * x + b + noise
# What's the overall error?
mse_actual = np.sum(np.power(y - y_noisy, 2)) / len(y)
# Visualize
plt.figure()
plt.scatter(x, y_noisy, label='Measurements', alpha=.7)
plt.plot(x, y, 'r', label='Underlying')
plt.legend()
plt.show()
# Let's learn something!
inputs = torch.from_numpy(x).type(torch.FloatTensor).unsqueeze(1)
targets = torch.from_numpy(y_noisy).type(torch.FloatTensor).unsqueeze(1)
# This is our model (a single weight plus a bias, no hidden layer)
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
loss_function = torch.nn.MSELoss()
# Let's train!
epochs = 50000
a_s, b_s = [], []
for epoch in range(epochs):
    # Reset the accumulated gradients
    optimizer.zero_grad()
    # Predict values using the current model
    preds = model(inputs)
    # How far off are we? (MSELoss expects (input, target))
    loss = loss_function(preds, targets)
    # Calculate the gradients
    loss.backward()
    # Update the model
    optimizer.step()
    # for p in model.parameters():
    #     print('Grads:', p.grad)
    # New parameters
    a_s.append(model.weight.item())
    b_s.append(model.bias.item())
    print(f"Epoch {epoch+1} -- loss = {loss.item()}")
# What does it predict right now?
with torch.no_grad():
    preds = model(inputs)
plt.figure()
plt.scatter(x, y_noisy, color='blue', label='Measurements', alpha=.7)
plt.plot(x, preds.numpy().ravel(), color='orange', label='Predictions', alpha=.7)
plt.plot(x, y, 'r', label='Underlying')
# Zoom out: fix the axis limits instead of letting matplotlib autoscale
plt.axis([0, 100, y.min() - 1, y.max() + 1])
plt.legend()
plt.show()
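An alternative to simply training longer, sketched here as a swapped-in technique rather than what this answer does: standardize x (zero mean, unit variance) so that weight and bias see gradients of the same scale, train with a much larger learning rate, then map the learned parameters back to the original x. The seed, lr=0.1 and 500 epochs are assumed values for this sketch; np.polyfit is used only as a closed-form reference.

```python
import numpy as np
import torch

# Assumption: same kind of data as the question (a = 0, random b, unit noise), seeded.
rng = np.random.default_rng(0)
a, b = 0.0, rng.random()
x = np.linspace(start=0, stop=100, num=100)
y_noisy = a * x + b + rng.normal(size=100)

# Standardize the input: zero mean, unit (population) variance.
x_mean, x_std = x.mean(), x.std()
x_scaled = (x - x_mean) / x_std

inputs = torch.from_numpy(x_scaled).float().unsqueeze(1)
targets = torch.from_numpy(y_noisy).float().unsqueeze(1)

model = torch.nn.Linear(1, 1)
# With standardized inputs a much larger learning rate is stable (assumed value).
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_function = torch.nn.MSELoss()

for epoch in range(500):
    optimizer.zero_grad()
    loss = loss_function(model(inputs), targets)
    loss.backward()
    optimizer.step()

# Map the parameters learned on the scaled input back to the original x:
# y = w_s * (x - mean) / std + b_s  =>  slope = w_s / std, intercept = b_s - w_s * mean / std
w_scaled = model.weight.item()
b_scaled = model.bias.item()
slope = w_scaled / x_std
intercept = b_scaled - w_scaled * x_mean / x_std
print(f"learned: slope = {slope:.4f}, intercept = {intercept:.4f}")

# Closed-form least-squares fit for comparison.
ref_slope, ref_intercept = np.polyfit(x, y_noisy, 1)
print(f"polyfit: slope = {ref_slope:.4f}, intercept = {ref_intercept:.4f}")
```

Because the standardized input has mean 0 and variance 1, the loss curvature is the same along both parameters, so a few hundred SGD steps are enough for the intercept to converge instead of tens of thousands.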