首页 > 解决方案 > 运行 matplotlib 绘图后，Scikit-learn 线性模型拟合抛出 ValueError

问题描述

我正在运行 Aurélien Géron 所著《Hands-On Machine Learning with Scikit-Learn and TensorFlow》（中文版《机器学习实战：基于 Scikit-Learn 和 TensorFlow》）第一章中的代码。

我试图运行的代码是:

# Code example
# Chapter 1 example from the question: load the OECD/IMF data, plot it, then
# fit a LinearRegression. Per the question, model.fit raises
# "ValueError: illegal value in 4-th argument of internal None" when it runs
# after plt.show(), but succeeds when plt.show() is omitted.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.linear_model

# Load the data
# NOTE(review): `datapath` is not defined in this snippet; presumably it is
# the data directory set up earlier in the book's notebook.
oecd_bli = pd.read_csv(datapath + "oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv(datapath + "gdp_per_capita.csv",thousands=',',delimiter='\t',
                             encoding='latin1', na_values="n/a")

# Prepare the data
# NOTE(review): prepare_country_stats is not defined in this snippet either;
# it is assumed to come from earlier in the notebook.
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
# np.c_ turns each column into a 2-D (n, 1) array.
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

# Visualize the data
country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
plt.show()

# Select a linear model
model = sklearn.linear_model.LinearRegression()

# Train the model
# This is the call the question reports as failing after plt.show() above.
model.fit(X, y)

代码在 `model.fit(X, y)` 这一步失败，回溯信息如下:

ValueError                                Traceback (most recent call last)
 in 
     23 
     24 # # Train the model
---> 25 model.fit(X, y)
     26 
     27 # # Make a prediction for Cyprus

~\AppData\Local\Programs\Python\venv\ds\lib\site-packages\sklearn\linear_model\_base.py in fit(self, X, y, sample_weight)
    531         else:
    532             self.coef_, self._residues, self.rank_, self.singular_ = \
--> 533                 linalg.lstsq(X, y)
    534             self.coef_ = self.coef_.T
    535 

~\AppData\Local\Programs\Python\venv\ds\lib\site-packages\scipy\linalg\basic.py in lstsq(a, b, cond, overwrite_a, overwrite_b, check_finite, lapack_driver)
   1223             raise LinAlgError("SVD did not converge in Linear Least Squares")
   1224         if info < 0:
-> 1225             raise ValueError('illegal value in %d-th argument of internal %s'
   1226                              % (-info, lapack_driver))
   1227         resids = np.asarray([], dtype=x.dtype)

ValueError: illegal value in 4-th argument of internal None

但是，当我去掉 `plt.show()` 命令后重新运行 fit 函数时，它可以正常工作:

# Same plotting call as above, but WITHOUT plt.show(); in this order the
# subsequent fit succeeds.
country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')

model.fit(X, y) # works OK

# Make a prediction for Cyprus
X_new = [[22587]]  # Cyprus' GDP per capita
print(model.predict(X_new)) # outputs [[ 5.96242338]]

这种行为非常奇怪。不知道是否与我的软件包版本有关。以下是我当前的软件包版本:

pip freeze | grep -E "numpy|pandas|scipy|matplotlib|sci"
matplotlib==3.2.1
numpy==1.18.4
pandas==0.25.3
scikit-image==0.16.2
scikit-learn==0.22
scipy==1.4.1

标签: python, matplotlib, scikit-learn

解决方案


我把这段代码运行了 10 次，每次都成功完成，所以您的代码中似乎遗漏了某些内容。下面是完整代码：其中出错的那部分代码在循环中试验了 10 次，并打印了结果。

# Common imports
# NOTE(review): numpy is imported twice in this block, matplotlib.pyplot is
# re-imported further down, and `os` is not used in the visible code; the
# duplicates are harmless but redundant.
import numpy as np
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.linear_model

# to make this notebook's output stable across runs
np.random.seed(42)

# To plot pretty figures
# `%matplotlib inline` is an IPython/Jupyter magic command; it is not valid
# syntax in a plain .py script.
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

# Enlarge axis-label and tick-label font sizes for all later figures.
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

def prepare_country_stats(oecd_bli, gdp_per_capita):
    """Join OECD life-satisfaction data with GDP-per-capita figures.

    Neither input DataFrame is modified. The original version renamed and
    re-indexed ``gdp_per_capita`` in place, so calling it a second time with
    the same frame failed with ``KeyError: 'Country'``.

    Parameters
    ----------
    oecd_bli : pandas.DataFrame
        Long-format OECD table with at least the columns ``INEQUALITY``,
        ``Country``, ``Indicator`` and ``Value``.
    gdp_per_capita : pandas.DataFrame
        Table with a ``Country`` column and a ``"2015"`` column holding the
        GDP-per-capita values.

    Returns
    -------
    pandas.DataFrame
        Indexed by country, with columns ``"GDP per capita"`` and
        ``"Life satisfaction"``. Rows are sorted by GDP, and the fixed
        positions 0, 1, 6, 8, 33, 34 and 35 are left out.

    Notes
    -----
    The row selection assumes the merged table has at least 36 rows;
    otherwise the positional lookup raises ``IndexError``.
    """
    # Keep only the "TOT" rows, then reshape to one row per country and one
    # column per indicator.
    oecd_bli = oecd_bli[oecd_bli["INEQUALITY"] == "TOT"]
    oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")
    # rename/set_index without inplace=True return new frames, so the
    # caller's gdp_per_capita is left untouched.
    gdp_per_capita = gdp_per_capita.rename(columns={"2015": "GDP per capita"})
    gdp_per_capita = gdp_per_capita.set_index("Country")
    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,
                                  left_index=True, right_index=True)
    full_country_stats = full_country_stats.sort_values(by="GDP per capita")
    # Positions (after sorting by GDP) that are deliberately left out of the
    # returned table. NOTE(review): presumably these are the countries the book
    # holds back from the training set; confirm against the book.
    remove_indices = [0, 1, 6, 8, 33, 34, 35]
    # sorted() makes the row order explicit instead of relying on set
    # iteration order.
    keep_indices = sorted(set(range(36)) - set(remove_indices))
    return full_country_stats[["GDP per capita", "Life satisfaction"]].iloc[keep_indices]


# Load the data
# Files are read from the current working directory here (the question's
# version prefixes the paths with `datapath`).
oecd_bli = pd.read_csv("oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv("gdp_per_capita.csv",thousands=',',delimiter='\t',
                             encoding='latin1', na_values="n/a")


# Bare expressions such as `oecd_bli.head(3)` only display output in a
# notebook; the commented lines after each one show what was displayed.
oecd_bli.head(3)
#  LOCATION    Country INDICATOR  ... Value Flag Codes            Flags
#0      AUS  Australia   HO_BASE  ...   1.1          E  Estimated value
#1      AUT    Austria   HO_BASE  ...   1.0        NaN              NaN
#2      BEL    Belgium   HO_BASE  ...   2.0        NaN              NaN


# NOTE(review): the output below already shows `Country` as the index. That
# only happens after prepare_country_stats has run, because it calls
# set_index("Country", inplace=True) on this frame. The cells were therefore
# likely executed out of order.
gdp_per_capita.head(3)
#                                            Subject Descriptor  ... #Estimates Start After
#Country                                                         ...
#Afghanistan  Gross domestic product per capita, current prices  ...                #2013.0
#Albania      Gross domestic product per capita, current prices  ...                #2010.0
#Algeria      Gross domestic product per capita, current prices  ...                #2014.0


# Prepare the data
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
# np.c_ stacks each Series as a single column, giving 2-D (n, 1) arrays
# (see the displayed slices below).
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

X[0:3]
#array([[ 9054.914],
#       [ 9437.372],
#       [12239.894]])

y[0:3]
#array([[6. ],
#       [5.6],
#       [4.9]])

# Repeat the question's plot -> plt.show() -> fit -> predict sequence ten
# times and collect the predictions. As the output below shows, every run
# succeeded and gave the same prediction.
results = list()
for i in range(10):
    # Visualize the data
    country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
    plt.show()

    # Select a linear model
    model = sklearn.linear_model.LinearRegression()

    # Train the model
    model.fit(X, y)

    # Make a prediction for Cyprus
    X_new = [[22587]]  # Cyprus' GDP per capita
    results.append(model.predict(X_new))


print(results)
#[array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]])]

以及我的软件包版本:

pip freeze | grep -E "numpy|pandas|scipy|matplotlib|sci"
matplotlib==3.1.2
numpy==1.17.4
pandas==0.25.3
pandas-flavor==0.2.0
scikit-learn==0.22.1
scikit-plot==0.3.7
scipy==1.4.1

推荐阅读