Dimension problem when using scikit-learn LinearRegression.predict()

Problem description

I am training a polynomial regression over a series of degrees and trying to call predict() on a list of inputs.

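import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# X_train and y_train are defined earlier (not shown)
inputs = np.linspace(0, 10, 100).reshape(-1, 1)

for i, deg in enumerate([1, 3, 6, 9]):
    poly = PolynomialFeatures(degree=deg)
    X_poly = poly.fit_transform(X_train.reshape(-1, 1))
    linreg = LinearRegression().fit(X_poly, y_train)
    print(linreg.predict(inputs))  # raises the ValueError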

When I call predict(), I get the following traceback:

ValueError                                Traceback (most recent call last)
<ipython-input-5-4100ae3f3ba3> in <module>()
     13     return
     14 
---> 15 answer_one()

<ipython-input-5-4100ae3f3ba3> in answer_one()
      9         X_poly = PolynomialFeatures(degree=deg).fit_transform(X_train.reshape(-1,1))
     10         linreg = LinearRegression().fit(X_poly, y_train)
---> 11         print(linreg.predict(inputs))
     12         # print(linreg.score(X_poly, y_train))
     13     return

/opt/conda/lib/python3.6/site-packages/sklearn/linear_model/base.py in predict(self, X)
    266             Returns predicted values.
    267         """
--> 268         return self._decision_function(X)
    269 
    270     _preprocess_data = staticmethod(_preprocess_data)

/opt/conda/lib/python3.6/site-packages/sklearn/linear_model/base.py in _decision_function(self, X)
    251         X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
    252         return safe_sparse_dot(X, self.coef_.T,
--> 253                                dense_output=True) + self.intercept_
    254 
    255     def predict(self, X):

/opt/conda/lib/python3.6/site-packages/sklearn/utils/extmath.py in safe_sparse_dot(a, b, dense_output)
    187         return ret
    188     else:
--> 189         return fast_dot(a, b)
    190 
    191 

ValueError: shapes (100,1) and (2,) not aligned: 1 (dim 1) != 2 (dim 0)

The (100,1) shape clearly belongs to the inputs array, but I'm not sure which object has the shape (2,).

Tags: scikit-learn

Solution


When you train the model on polynomial features:

X_poly = poly.fit_transform(X_train.reshape(-1,1))

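You need to make sure that prediction uses polynomial features too. The (2,) shape in the error is linreg.coef_: with degree=1, PolynomialFeatures produces two columns (the bias term and x), so the model expects two features per row, while inputs has only one. Before calling print(linreg.predict(inputs)), the inputs must therefore be transformed with the same poly: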

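# Use a separate name so the raw inputs can be transformed again for the next degree in the loop
inputs_poly = poly.transform(inputs)
print(linreg.predict(inputs_poly))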
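
For reference, the fix can be put together into a runnable sketch of the whole loop. The training data here is a synthetic stand-in, since the question does not show X_train or y_train; the names predictions and coef_shapes are only there for illustration. It shows that the transformed inputs and linreg.coef_ have matching feature counts for each degree.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# Assumption: synthetic data standing in for the question's X_train / y_train.
rng = np.random.RandomState(0)
X_train = rng.uniform(0, 10, size=30)
y_train = np.sin(X_train) + 0.1 * rng.randn(30)

inputs = np.linspace(0, 10, 100).reshape(-1, 1)

predictions = {}   # hypothetical helper: degree -> predicted values
coef_shapes = {}   # hypothetical helper: degree -> shape of linreg.coef_

for deg in [1, 3, 6, 9]:
    poly = PolynomialFeatures(degree=deg)
    # One input feature expands to deg + 1 columns (bias, x, x^2, ...).
    X_poly = poly.fit_transform(X_train.reshape(-1, 1))
    linreg = LinearRegression().fit(X_poly, y_train)

    # Apply the same fitted transformer to the raw inputs before predicting.
    inputs_poly = poly.transform(inputs)
    predictions[deg] = linreg.predict(inputs_poly)
    coef_shapes[deg] = linreg.coef_.shape

    print(deg, inputs_poly.shape, linreg.coef_.shape, predictions[deg].shape)
```

For degree 1, linreg.coef_ has shape (2,), which is the second shape from the error message. Keeping the transformed array under its own name (inputs_poly) leaves the raw inputs untouched, so they can be transformed again for the next degree.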
