python - Python/Sklearn 中的循环回归模型
问题描述
我试图在同一组输入/自变量上系统地回归几个不同的因变量(国家),并希望使用 Sklearn 在 Python 中以循环方式执行此操作。因变量如下所示:
Europe UK Japan USA Canada
Jan-10 10 13 39 42 16
Feb-10 13 16 48 51 19
Mar-10 15 18 54 57 21
Apr-10 12 15 45 48 18
May-10 11 14 42 45 17
而自变量看起来像这样:
Input 1 Input 2 Input 3 Input 4
Jan-10 90 50 3 41
Feb-10 95 54 5 43
Mar-10 92 52 1 45
Apr-10 91 60 1 49
May-10 90 67 11 49
我发现手动回归它们+一次存储一个预测很容易(即欧洲在所有四个输入上,然后是日本等),但还没有弄清楚如何编写一个可以一次性完成所有这些的单个循环函数。我怀疑我可能需要使用列表/字典来存储因变量并一个接一个地调用它们,但不太知道如何以 Pythonic 方式编写它。
单个循环的现有代码如下所示:
x = pd.DataFrame('countryinputs.csv')
countries = pd.DataFrame('countryoutputs.csv')
y = countries['Europe']
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
regressor.fit(X, y)
y_pred = regressor.predict(X)
解决方案
只需遍历列名。然后将名称传递给定义的函数。实际上,您可以将过程包装在字典理解中并传递给DataFrame
构造函数以返回预测值的数据框(与原始数据框的形状相同):
X = pd.DataFrame(...)
countries = pd.DataFrame(...)
def reg_proc(label):
y = countries[label]
regressor = LinearRegression()
regressor.fit(X, y)
y_pred = regressor.predict(X)
return(y_pred)
pred_df = pd.DataFrame({lab: reg_proc(lab) for lab in countries.columns},
columns = countries.columns)
使用随机的种子数据来证明以下工具将是您的国家/地区:
数据
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
np.random.seed(7172018)
tools = pd.DataFrame({'pandas': np.random.uniform(0,1000,50),
'r': np.random.uniform(0,1000,50),
'julia': np.random.uniform(0,1000,50),
'sas': np.random.uniform(0,1000,50),
'spss': np.random.uniform(0,1000,50),
'stata': np.random.uniform(0,1000,50)
},
columns=['pandas', 'r', 'julia', 'sas', 'spss', 'stata'])
X = pd.DataFrame({'Input1': np.random.randn(50)*10,
'Input2': np.random.randn(50)*10,
'Input3': np.random.randn(50)*10,
'Input4': np.random.randn(50)*10})
模型
def reg_proc(label):
y = tools[label]
regressor = LinearRegression()
regressor.fit(X, y)
y_pred = regressor.predict(X)
return(y_pred)
pred_df = pd.DataFrame({lab: reg_proc(lab) for lab in tools.columns},
columns = tools.columns)
print(pred_df.head(10))
# pandas r julia sas spss stata
# 0 547.631679 576.025733 682.390046 507.767567 246.020799 557.648181
# 1 577.334819 575.992992 280.579234 506.014191 443.044139 396.044620
# 2 430.494827 576.211105 541.096721 441.997575 386.309627 558.472179
# 3 440.662962 524.582054 406.849303 420.017656 508.701222 393.678200
# 4 588.993442 472.414081 453.815978 479.208183 389.744062 424.507541
# 5 520.215513 489.447248 670.708618 459.375294 314.008988 516.235188
# 6 515.266625 459.292370 477.485995 436.398180 446.777292 398.826234
# 7 423.930650 414.069751 629.444118 378.059735 448.760240 449.062734
# 8 549.769034 406.531405 653.557937 441.425445 348.725447 456.089921
# 9 396.826924 399.327683 717.285415 361.235709 444.830491 429.967976
推荐阅读
- django - 删除了一个文件夹,得到 ModuleNotFoundError: No module named "api"
- python - kivy buildozer 无法修补文件 setup.py
- python - 如何用 Pytorch 张量中的某个值替换每一行填充为零的行?
- r - RStudio 没有读取数据框并且连接卡住了?
- c++ - 我似乎无法在 C++ 中捕捉到这个异常
- amazon-web-services - 将 EC2 实例迁移到不同的子网
- python - 在 Plotly 中为图表注释添加自定义标签
- python - 如何实现一个非全连接层作为最后一层的神经网络?
- javascript - Vue.js 更新的属性不会在子组件上更改
- java - 尝试运行双重项目时出现 BindException