Scikit-learn linear model fit after running matplotlib plot returns ValueError


Question

I am running the code from the first chapter of Aurélien Géron's Hands-On Machine Learning with Scikit-Learn and TensorFlow.

The code I am trying to run is:

# Code example
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.linear_model

# Load the data
oecd_bli = pd.read_csv(datapath + "oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv(datapath + "gdp_per_capita.csv",thousands=',',delimiter='\t',
                             encoding='latin1', na_values="n/a")

# Prepare the data
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

# Visualize the data
country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
plt.show()

# Select a linear model
model = sklearn.linear_model.LinearRegression()

# Train the model
model.fit(X, y)

It fails at the step model.fit(X, y) with the traceback below:

ValueError                                Traceback (most recent call last)
 in 
     23 
     24 # # Train the model
---> 25 model.fit(X, y)
     26 
     27 # # Make a prediction for Cyprus

~\AppData\Local\Programs\Python\venv\ds\lib\site-packages\sklearn\linear_model\_base.py in fit(self, X, y, sample_weight)
    531         else:
    532             self.coef_, self._residues, self.rank_, self.singular_ = \
--> 533                 linalg.lstsq(X, y)
    534             self.coef_ = self.coef_.T
    535 

~\AppData\Local\Programs\Python\venv\ds\lib\site-packages\scipy\linalg\basic.py in lstsq(a, b, cond, overwrite_a, overwrite_b, check_finite, lapack_driver)
   1223             raise LinAlgError("SVD did not converge in Linear Least Squares")
   1224         if info < 0:
-> 1225             raise ValueError('illegal value in %d-th argument of internal %s'
   1226                              % (-info, lapack_driver))
   1227         resids = np.asarray([], dtype=x.dtype)

ValueError: illegal value in 4-th argument of internal None

However, when I re-run the fit function without the plt.show() command, it works fine:

country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')

model.fit(X, y) # works OK

# # Make a prediction for Cyprus
X_new = [[22587]]  # Cyprus' GDP per capita
print(model.predict(X_new)) # outputs [[ 5.96242338]]

This behavior is very odd. I am not sure whether it is caused by my package versions. Here are the versions I currently have installed:

pip freeze | grep -E "numpy|pandas|scipy|matplotlib|sci"
matplotlib==3.2.1
numpy==1.18.4
pandas==0.25.3
scikit-image==0.16.2
scikit-learn==0.22
scipy==1.4.1
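
Before suspecting the plot, it may be worth confirming that X and y look the way a least-squares solver expects at the moment fit is called. The following is only a minimal sketch: check_regression_inputs is a hypothetical helper (not part of scikit-learn), and the demo arrays are stand-ins shaped like the ones in the question, not the real OECD data.

```python
import numpy as np

def check_regression_inputs(X, y):
    """Summarize the properties of X and y that a least-squares solver relies on."""
    X = np.asarray(X)
    y = np.asarray(y)
    # np.isfinite only works on numeric dtypes, so check that first.
    numeric = bool(np.issubdtype(X.dtype, np.number) and np.issubdtype(y.dtype, np.number))
    all_finite = bool(numeric and np.isfinite(X).all() and np.isfinite(y).all())
    return {
        "X_shape": X.shape,
        "y_shape": y.shape,
        "X_dtype": str(X.dtype),
        "y_dtype": str(y.dtype),
        "numeric": numeric,
        "all_finite": all_finite,
        "rows_match": X.shape[0] == y.shape[0],
    }

# Stand-in values with the same (n, 1) layout as X and y in the question
X_demo = np.array([[9054.914], [9437.372], [12239.894]])
y_demo = np.array([[6.0], [5.6], [4.9]])
report = check_regression_inputs(X_demo, y_demo)
print(report)
```

If this report looks identical with and without plt.show(), the data itself is probably not what changes between the two runs.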

Recommended answer

I ran the code 10 times and it finished successfully every time, so it seems that something is missing from the code you posted. Below is the full code, with the part that fails for you repeated in 10 trials and the results printed.

# Common imports
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.linear_model

# to make this notebook's output stable across runs
np.random.seed(42)

# To plot pretty figures (%matplotlib is an IPython/Jupyter magic; drop it in a plain script)
%matplotlib inline
import matplotlib as mpl

mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

def prepare_country_stats(oecd_bli, gdp_per_capita):
    oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]
    oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")
    gdp_per_capita.rename(columns={"2015": "GDP per capita"}, inplace=True)
    gdp_per_capita.set_index("Country", inplace=True)
    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,
                                  left_index=True, right_index=True)
    full_country_stats.sort_values(by="GDP per capita", inplace=True)
    remove_indices = [0, 1, 6, 8, 33, 34, 35]
    keep_indices = list(set(range(36)) - set(remove_indices))
    return full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[keep_indices]


# Load the data
oecd_bli = pd.read_csv("oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv("gdp_per_capita.csv",thousands=',',delimiter='\t',
                             encoding='latin1', na_values="n/a")


oecd_bli.head(3)
#  LOCATION    Country INDICATOR  ... Value Flag Codes            Flags
#0      AUS  Australia   HO_BASE  ...   1.1          E  Estimated value
#1      AUT    Austria   HO_BASE  ...   1.0        NaN              NaN
#2      BEL    Belgium   HO_BASE  ...   2.0        NaN              NaN


gdp_per_capita.head(3)
#                                            Subject Descriptor  ... #Estimates Start After
#Country                                                         ...
#Afghanistan  Gross domestic product per capita, current prices  ...                #2013.0
#Albania      Gross domestic product per capita, current prices  ...                #2010.0
#Algeria      Gross domestic product per capita, current prices  ...                #2014.0


# Prepare the data
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

X[0:3]
#array([[ 9054.914],
#       [ 9437.372],
#       [12239.894]])

y[0:3]
#array([[6. ],
#       [5.6],
#       [4.9]])

results = list()
for i in range(10):
    # Visualize the data
    country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
    plt.show()

    # Select a linear model
    model = sklearn.linear_model.LinearRegression()

    # Train the model
    model.fit(X, y)

    # Make a prediction for Cyprus
    X_new = [[22587]]  # Cyprus' GDP per capita
    results.append(model.predict(X_new))


print(results)
#[array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]])]
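
The traceback in the question shows the error being raised inside scipy.linalg.lstsq, which LinearRegression.fit calls to solve the least-squares problem. As a rough cross-check, the same kind of fit can be solved with numpy.linalg.lstsq instead; if that works in the failing environment, the data is probably fine and the problem is more likely in the SciPy call. This is a sketch under assumptions: fit_line_lstsq and predict_line are hypothetical helper names, and the demo points are made up to lie on y = 1 + 2x rather than taken from the OECD data.

```python
import numpy as np

def fit_line_lstsq(X, y):
    """Ordinary least squares with an intercept, solved by numpy.linalg.lstsq."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    # Prepend a column of ones so the first coefficient acts as the intercept.
    A = np.hstack([np.ones((X.shape[0], 1)), X])
    coef, _residuals, _rank, _singular_values = np.linalg.lstsq(A, y, rcond=None)
    return coef

def predict_line(coef, X_new):
    """Apply the fitted intercept and slope to new inputs."""
    X_new = np.asarray(X_new, dtype=float)
    A_new = np.hstack([np.ones((X_new.shape[0], 1)), X_new])
    return A_new @ coef

# Made-up points lying exactly on y = 1 + 2x (not the OECD data)
X_demo = np.array([[0.0], [1.0], [2.0], [3.0]])
y_demo = np.array([[1.0], [3.0], [5.0], [7.0]])
coef = fit_line_lstsq(X_demo, y_demo)
prediction = predict_line(coef, [[4.0]])
print(coef.ravel(), prediction)
```

With the real X and y from the code above, predict_line(fit_line_lstsq(X, y), [[22587]]) should come out close to the value returned by model.predict, since both solve the same least-squares problem.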

My package versions, for comparison:

pip freeze | grep -E "numpy|pandas|scipy|matplotlib|sci"
matplotlib==3.1.2
numpy==1.17.4
pandas==0.25.3
pandas-flavor==0.2.0
scikit-learn==0.22.1
scikit-plot==0.3.7
scipy==1.4.1
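
Since the two environments differ only in a few minor versions, it can help to print the versions from inside the same Python session that runs the failing code, rather than relying on whichever pip is first on the PATH. A minimal sketch, assuming Python 3.8 or later for importlib.metadata; collect_versions is a hypothetical helper name:

```python
from importlib import metadata

def collect_versions(package_names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in package_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

packages = ["numpy", "pandas", "scipy", "matplotlib", "scikit-learn"]
for name, version in collect_versions(packages).items():
    print(f"{name}=={version}")
```

Running this in both environments gives directly comparable lists and also works on Windows shells where grep is not available.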
