data-visualization - 使用 Statsmodels 进行简单逻辑回归:添加截距并可视化逻辑回归方程
问题描述
使用 Statsmodels,我正在尝试生成一个简单的逻辑回归模型,以根据身高 (Hgt) 预测一个人是否吸烟 (Smoke)。
我感觉需要将截距包含在逻辑回归模型中,但我不确定如何使用 add_constant() 函数来实现截距。另外,我不确定为什么会产生以下错误。
这是数据集,Pulse.CSV:https ://drive.google.com/file/d/1FdUK9p4Dub4NXsc-zHrYI-AGEEBkX98V/view?usp=sharing
完整的代码和输出在这个 PDF 文件中:https ://drive.google.com/file/d/1kHlrAjiU7QvFXF2a7tlTSFPgfpq9bOXJ/view?usp=sharing
import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
raw_data
x1 = raw_data['Hgt']
y = raw_data['Smoke']
reg_log = sm.Logit(y,x1,missing='Drop')
results_log = reg_log.fit()
def f(x,b0,b1):
return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
f_sorted = np.sort(f(x1,results_log.params[0],results_log.params[1]))
x_sorted = np.sort(np.array(x1))
plt.scatter(x1,y,color='C0')
plt.xlabel('Hgt', fontsize = 20)
plt.ylabel('Smoked', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='C8')
plt.show()
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_value(self, series, key)
4729 try:
-> 4730 return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
4731 except KeyError as e1:
((( Truncated for brevity )))
IndexError: index out of bounds
解决方案
Statsmodels回归中默认不添加拦截,但如果需要,您可以手动包含它。
import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
raw_data
x1 = raw_data['Hgt']
y = raw_data['Smoke']
x1 = sm.add_constant(x1)
reg_log = sm.Logit(y,x1,missing='Drop')
results_log = reg_log.fit()
results_log.summary()
def f(x,b0,b1):
return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
f_sorted = np.sort(f(x1,results_log.params[0],results_log.params[1]))
x_sorted = np.sort(np.array(x1))
plt.scatter(x1['Hgt'],y,color='C0')
plt.xlabel('Hgt', fontsize = 20)
plt.ylabel('Smoked', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='C8')
plt.show()
这也将解决错误,因为您的初始代码中没有拦截。资源
推荐阅读
- java - 无法从 onClickListener 添加到 arraylist
- java - 使用 java 创建 zip 并使其可下载
- vba - 邮件合并重命名文件为word文档中的值
- android - Moshi 将嵌套的 JSON 值映射到字段
- parsing - 使用 OCaml Menhir,有没有办法在处理之前访问某些内容?
- list - 实现列表的更快方法?
- android - 实现 AndroidX SeekBarPreference 时的问题
- swift - 尝试使用 CodingKeys 解码时出错
- wordpress - 未登录的用户看不到商店,必须登录,成功登录的用户必须重定向到商店
- mongodb - 如何使用 Golang 比较两个 bson.M 数据集