python - 为什么数据集的 SGD 损失与 pytorch 代码与用于线性回归的暂存 python 代码不匹配?
问题描述
我正在尝试在葡萄酒数据集上实现多元线性回归。但是当我将 Pytorch 的结果与 Python 的临时代码进行比较时,损失并不相同。
我的刮码:
功能:
def yinfer(X, beta):
return beta[0] + np.dot(X,beta[1:])
def cost(X, Y, beta):
sum = 0
m = len(Y)
for i in range(m):
sum = sum + ( yinfer(X[i],beta) - Y[i])*(yinfer(X[i],beta) - Y[i])
return sum/(1.0*m)
主要代码:
alpha = 0.005
b=[0,0.04086357 ,-0.02831656 ,0.09622949 ,-0.15162516 ,0.60188454 ,0.47528714,
-0.6066466 ,-0.22995654 ,-0.58388734 ,0.20954669 ,-0.67851365]
beta = np.array(b)
print(beta)
iterations = 1000
arr_cost = np.zeros((iterations,2))
m = len(Y)
temp_beta = np.zeros(12)
for i in range(iterations):
for k in range(m):
temp_beta[0] = yinfer(X[k,:], beta) - Y[k]
temp_beta[1:] = (yinfer(X[k,:], beta) - Y[k])*X[k,:]
beta = beta - alpha*temp_beta/(1.0*m) #(m*np.linalg.norm(temp_beta))
arr_cost[i] = [i,cost(X,Y,beta)]
#print(cost(X,Y,beta))
plt.scatter(arr_cost[0:iterations,0], arr_cost[0:iterations,1])
我使用了与 Pytorch 代码中相同的权重
我的 Pytorch 代码:
class LinearRegression(nn.Module):
def __init__(self,n_input_features):
super(LinearRegression,self).__init__()
self.linear=nn.Linear(n_input_features,1)
# self.linear.weight.data=b.view(1,-1)
self.linear.bias.data.fill_(0.0)
nn.init.xavier_uniform_(self.linear.weight)
# nn.init.xavier_normal_(self.linear.bias)
def forward(self,x):
y_predicted=self.linear(x)
return y_predicted
model=LinearRegression(11)
criterion = nn.MSELoss()
num_epochs=1000
for epoch in range(num_epochs):
for x,y in train_data:
y_pred=model(x)
loss=criterion(y,y_pred)
# print(loss)
loss.backward()
optimizer.step()
optimizer.zero_grad()
我的数据加载器:
class Data(Dataset):
def __init__(self):
self.x=x_train
self.y=y_train
self.len=self.x.shape[0]
def __getitem__(self,index):
return self.x[index],self.y[index]
def __len__(self):
return self.len
dataset=Data()
train_data=DataLoader(dataset=dataset,batch_size=1,shuffle=False)
有人可以告诉我为什么会发生这种情况,或者我的代码中是否有任何错误?
解决方案
推荐阅读
- flutter - 使用循环进度指示器时底部溢出
- c# - 访问 Azure 密钥保管库时出现间歇性错误 - 密钥集不存在
- mysql - 如何在sql查询中使用动态限制
- sapui5 - 防止 sap.m.popover 在点击时关闭
- r - 从环境中移除所有对象,除了那些匹配给定模式的对象
- c++ - 可变参数模板仅在前向声明时编译
- powerquery - PowerQuery 将记录列表转换为分隔字符串
- java - 即使使用了必需的属性,也不会检测到重复的 XML 标记
- mongodb - 有什么方法可以在不使用副本集的情况下在 Mongodb4.0 中执行 ACID 事务
- angular - 从当前路线导航到另一条路线时清除本地存储数据