首页 > 技术文章 > pytorch——linear model2

xinrui-wang 2022-02-04 10:58 原文

#模型x*W+b,三维图象横坐标是w,纵坐标是b,竖坐标是损失函数
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from modulefinder import *
from mpl_toolkits.mplot3d import Axes3D
x_data=[1,2,3]
y_data=[2,4,6]
def forward(x,b):
return x*w+b
def loss(x,y):
y_pred = forward(x,b)
return (y_pred - y) * (y_pred - y)
w_list=[] #随机w
mse_list=[] #mean square error=每个w对应的损失函数
for w in np.arange(0,4,0.1):
for b in np.arange(-2.0,2.0,0.1):
print('w=',w)
print('b=',b)
l_sum=0
for x_val,y_val in zip(x_data,y_data):#x_datay_datazip拼成x_val y_val
y_pred_val=forward(x_val,b) #y
loss_val=loss(x_val,y_val) #预测值y^和真实值y之间的平方差,损失函数
l_sum+=loss_val #求每个样本损失函数之和
print('x=',x_val,'y=',y_val,'y^=',y_pred_val,'每个样本的损失函数:',loss_val)

print('dataset数据集的平均损失函数mse:', l_sum / 3)
w_list.append(w)#w[]列表追加元素w
mse_list.append(l_sum / 3)#mse[]列表追加元素新的平均损失函数
fig=plt.figure()
ax=Axes3D(fig)
ax.plot_surface(w, b, mse_list,rstride=1,cstride=1, cmap=plt.get_cmap('rainbow'))
plt.xlabel(r'w',fontsize=20,color='cyan')
plt.ylabel(r'b',fontsize=20,color='cyan')
ax.plot_surface(w,b=1,mse_list,)
plt.show()

推荐阅读