python - 预期二维数组,得到一维数组:array=[1 3 5 6 7 8 9]?
问题描述
x=[1,3,5,6,7,8,9]
y=[4,5,6,9,3,4,6]
def linear_model_main(X_parameters,Y_parameters,predict_value):
# Create linear regression object
regr = linear_model.LinearRegression()
regr.fit(x, y)
predict_outcome = regr.predict(predict_value)
predictions = {}
predictions['intercept'] = regr.intercept
predictions['coefficient'] = regr.coef
predictions['predicted_value'] = predict_outcome
predicted_value = predict_outcome
#return predicted_value
return predictions
predictvalue = 7000
result = linear_model_main(x,y,predictvalue)
print ("Intercept value " , result['intercept'])
print ("coefficient" , result['coefficient'])
print ("Predicted value: ",result['predicted_value'])
调用 fit 函数时出现此错误:regr.fit(x, y)
ValueError:预期的 2D 数组,得到 1D 数组:array=[1 3 5 6 7 8 9]。如果您的数据具有单个特征,则使用 array.reshape(-1, 1) 重塑您的数据,如果它包含单个样本,则使用 array.reshape(1, -1) 。
解决方案
这是您更正的代码:
from sklearn import linear_model
import numpy as np
x=[1,3,5,6,7,8,9]
y=[4,5,6,9,3,4,6]
def linear_model_main(X_parameters,Y_parameters,predict_value):
# Create linear regression object
regr = linear_model.LinearRegression()
regr.fit(np.array(x).reshape(-1,1), np.array(y).reshape(-1,1))
predict_outcome = regr.predict(np.array(predict_value).reshape(-1,1))
predictions = {}
predictions['intercept'] = regr.intercept_
predictions['coefficient'] = regr.coef_
predictions['predicted_value'] = predict_outcome
predicted_value = predict_outcome
#return predicted_value
return predictions
predictvalue = 7000
result = linear_model_main(x,y,predictvalue)
print ("Intercept value " , result['intercept'])
print ("coefficient" , result['coefficient'])
print ("Predicted value: ",result['predicted_value'])
你有几个错误,我将在下面解释:
1-首先,您需要将输入转换为 NumPy 数组,而不是 1 x n 数组,您需要 n x 1 数组。你得到的错误是因为这个(这就是 scikit-learn 模型的设计方式)。
2- 其次,您错过了属性名称末尾的下划线,例如“intercept_”
3- 预测值也应该是一个 n×1 数组。
编辑:这是情节的代码:
plt.scatter(x,y)
axes = plt.gca()
x_vals = np.array(axes.get_xlim())
y_vals = result['intercept'][0] + result['coefficient'][0] * x_vals
plt.plot(x_vals, y_vals, '--')
plt.show()
推荐阅读
- excel - 为要导出到 txt 的不同列插入另一个过滤器命令
- java - 为什么当多对一是惰性的时,Hibernate 会出现 StackOverflowError?
- flutter - zsh:找不到命令:颤动?
- azure-data-factory - ADF 复制任务字段类型 boolean 为小写
- jenkins - 詹金斯:用分支名称触发另一个工作
- java - 使用 Location API Android 时,服务的用途是什么?
- javascript - 使用星期几返回特定对象
- javascript - 筛选索引中的实体选择用户
- selenium - 如何在不可见元素上设置日期
- groovy - 通过在 jmeter 中的 JSR223 Sampler(groovy) 中运行代码来获取意外令牌