scikit-learn - 使用 scikit-learn LinearRegression.predict() 遇到维度问题
问题描述
我正在对一系列维度进行多项式回归训练,并尝试将 predict() 用于输入列表。
inputs = np.linspace(0,10,100).reshape(-1,1)
for i, deg in enumerate([1, 3, 6, 9]):
poly = PolynomialFeatures(degree=deg)
X_poly = poly.fit_transform(X_train.reshape(-1,1))
linreg = LinearRegression().fit(X_poly, y_train)
print(linreg.predict(inputs))
当我调用 predict() 时,我得到以下回溯:
ValueError Traceback (most recent call last)
<ipython-input-5-4100ae3f3ba3> in <module>()
13 return
14
---> 15 answer_one()
<ipython-input-5-4100ae3f3ba3> in answer_one()
9 X_poly = PolynomialFeatures(degree=deg).fit_transform(X_train.reshape(-1,1))
10 linreg = LinearRegression().fit(X_poly, y_train)
---> 11 print(linreg.predict(inputs))
12 # print(linreg.score(X_poly, y_train))
13 return
/opt/conda/lib/python3.6/site-packages/sklearn/linear_model/base.py in predict(self, X)
266 Returns predicted values.
267 """
--> 268 return self._decision_function(X)
269
270 _preprocess_data = staticmethod(_preprocess_data)
/opt/conda/lib/python3.6/site-packages/sklearn/linear_model/base.py in _decision_function(self, X)
251 X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
252 return safe_sparse_dot(X, self.coef_.T,
--> 253 dense_output=True) + self.intercept_
254
255 def predict(self, X):
/opt/conda/lib/python3.6/site-packages/sklearn/utils/extmath.py in safe_sparse_dot(a, b, dense_output)
187 return ret
188 else:
--> 189 return fast_dot(a, b)
190
191
ValueError: shapes (100,1) and (2,) not aligned: 1 (dim 1) != 2 (dim 0)
(100,1) 形状显然适用于输入数组,但我不确定对象的形状是 (2,)。
解决方案
当您使用 poly 训练分类器时:
X_poly = poly.fit_transform(X_train.reshape(-1,1))
您需要确保预测也使用多边形值:
print(linreg.predict(inputs))
在这种情况下,输入也必须是多边形:
inputs = poly.transform(inputs)
print(linreg.predict(inputs))
推荐阅读
- html - PrimeNG p-orderList 禁用多项选择
- ios - 如何将项目从初始(一次)第一次视图控制器传递到主视图控制器并使用核心数据保存该数据
- php - CakePHP 3.6.13 和 PHP 7.2.5:堆栈跟踪未显示
- c++11 - 从字符串转换为双精度无法按预期转换
- c - 如何解决对“AES_ctr128_encrypt”的未定义引用
- laravel - 使用 PhpPresentation 读取和替换文本
- computer-science - 十进制数字自然数的语义
- php - PHP improve data retrieving
- angular - 角度生产构建 FatalProcessOutOfMemory 错误
- chartist.js - Chartist Pie Char 自定义标签