python - scikit-learn - LinearRegression() 可以使用一个特征学习与直线不同的东西吗?
问题描述
我正在使用 scikit-learn 的 LinearRegression() 和时间序列数据,例如
time_in_s value
1539015300000000000 2.061695
1539016200000000000 40.178125
1539017100000000000 12.276094
...
因为它是一个单变量情况,我希望我的模型是一条直线,如y=m*x+c。当我做
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(df.time_in_s,
df.value,
test_size=0.3,
random_state=0,
shuffle=False)
regressor = LinearRegression().fit(X_train, y_train)
y_pred_train = regressor.predict(X_train)
y_pred_test = regressor.predict(X_test)
[...]
正如预期的那样,我得到了一条直线:。如果我使用shuffle=True
,我会得到一条曲线。
我正在努力理解shuffle
这里的作用以及为什么我可以学到与具有一个特征的直线不同的东西。我会很感激一个提示。
编辑:这是模型的属性
>>> #shuffle=False
>>> print(f"{regressor.coef_}")
[-1.6e-16]
>>> print(f"{regressor.intercept_}")
272.0575589244862
>>> #shuffle=True
>>> print(f"{regressor.coef_}")
[-7.76e-17]
>>> print(f"{regressor.intercept_}")
143.9711420915541
对于绘图:
start = 61000
stop = 61500
fig, ax1 = plt.subplots(figsize=(15, 5))
color='tab:red'
plt.plot(df.index[start:train_length].values.reshape(-1, 1),
df.value[start:train_length].values.reshape(-1, 1),
color=color)
color='tab:blue'
plt.plot(df.index[train_length:stop].values.reshape(-1, 1),
df.value[train_length:stop].values.reshape(-1, 1),
color=color)
color='tab:green'
plt.plot(df.index[start:train_length].values.reshape(-1, 1),
y_pred_train[start:],
color=color,
linestyle='dashed')
plt.plot(df.index[train_length:stop].values.reshape(-1, 1),
y_pred_test[:stop - train_length],
color=color,
linestyle='dashed')
ax1.tick_params(axis='y')
ax1.tick_params(axis='x')
解决方案
你没有得到曲线。如果您查看train_test_split的帮助页面,它会写道:
shuffle bool, default=True 拆分前是否对数据进行shuffle。如果 shuffle=False 则分层必须为 None。
我假设您的数据是根据 排序的df.time_in_s
,因此如果您不洗牌,您将对数据的前 70% 运行回归模型并预测最后 30%。
使用shuffle=True
时,行的顺序被交换,您将随机抽取 70% 的数据并预测另外 30% 的数据,而不考虑时间。您没有显示绘图代码,但我的猜测是您以有序的时间绘制了原始数据框,并且只是将预测放在了顶部,因此您得到了这条模糊线。
推荐阅读
- android - 当我只选择伦敦位置时,我在谷歌地图上崩溃了
- html - 将 CSS 伪元素添加到 SVG
- java - 设置 Spring Security Kerberos - spn 和 keytab
- python-3.x - tf.reshape() 没有给出预期的结果
- qt - Qml font.styleName 在 Windows 上没有改变
- apache-spark - SparkSQL - 一些分区出现在 HiveServer2 但不是 SparkSQL
- javascript - Java将JSP页面保存在缓存中
- docker - 是否可以结合“docker logs”输出和“docker exec”输出?
- python - 将图像转换为 numpy 数组,将其保存到 Excel 中并反转所有
- visual-studio-code - Visual Studio 代码中的 Powershell 在哪里?