Simple logistic regression with Statsmodels: Adding an intercept and visualizing the logistic regression equation

Question

Using Statsmodels, I am trying to fit a simple logistic regression model that predicts whether a person smokes (Smoke) from their height (Hgt).

I suspect the model needs an intercept, but I am not sure how to add one with the add_constant() function. I am also unsure why the code below raises the error shown after it.

This is the dataset, Pulse.CSV: https://drive.google.com/file/d/1FdUK9p4Dub4NXsc-zHrYI-AGEEBkX98V/view?usp=sharing

The full code and output are in this PDF file: https://drive.google.com/file/d/1kHlrAjiU7QvFXF2a7tlTSFPgfpq9bOXJ/view?usp=sharing

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
raw_data
x1 = raw_data['Hgt']
y = raw_data['Smoke'] 
reg_log = sm.Logit(y,x1,missing='Drop')
results_log = reg_log.fit()
def f(x,b0,b1):
    return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
f_sorted = np.sort(f(x1,results_log.params[0],results_log.params[1]))
x_sorted = np.sort(np.array(x1))
plt.scatter(x1,y,color='C0')
plt.xlabel('Hgt', fontsize = 20)
plt.ylabel('Smoked', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='C8')
plt.show()

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_value(self, series, key)
   4729         try:
-> 4730             return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
   4731         except KeyError as e1:
((( Truncated for brevity )))
IndexError: index out of bounds


Recommended answer

Statsmodels does not add an intercept by default when you pass arrays or DataFrames to models such as sm.Logit, but you can include one manually with sm.add_constant().

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
x1 = raw_data['Hgt']
y = raw_data['Smoke']

# Add a column of ones named 'const' so the model estimates an intercept
x = sm.add_constant(x1)

# The documented missing-data options are lowercase: 'none', 'drop', 'raise'
reg_log = sm.Logit(y, x, missing='drop')
results_log = reg_log.fit()

print(results_log.summary())

# Logistic function: P(Smoke = 1 | Hgt) = exp(b0 + b1*x) / (1 + exp(b0 + b1*x))
def f(x, b0, b1):
    return np.array(np.exp(b0 + x*b1) / (1 + np.exp(b0 + x*b1)))

# Look the coefficients up by name and evaluate the curve on the sorted heights
b0 = results_log.params['const']
b1 = results_log.params['Hgt']
x_sorted = np.sort(np.array(x1))
f_sorted = f(x_sorted, b0, b1)

plt.scatter(x1, y, color='C0')
plt.xlabel('Hgt', fontsize=20)
plt.ylabel('Smoked', fontsize=20)
plt.plot(x_sorted, f_sorted, color='C8')
plt.show()

This also resolves the error. Your initial code had no intercept, so the model fitted on the single Hgt column had only one parameter, and results_log.params[1] was out of bounds.
