How do I plot Linear Regression?

Question

I wish to learn how to use Plotly with Python for data analysis. I have been using this website as a reference.

My current code is as follows:

from plotly import tools
import plotly as py
import plotly.graph_objs as go
from sklearn import linear_model

py.offline.init_notebook_mode(connected=True)

# Create linear regression object
regr = linear_model.LinearRegression()

# Train the model using the training sets
regr.fit(X_train, y_train)

p1 = go.Scatter(x=X_test, 
                y=y_test, 
                mode='markers',
                marker=dict(color='black')
               )

p2 = go.Scatter(x=X_test, 
                y=regr.predict(X_test),
                mode='lines',
                line=dict(color='blue', width=3)
                )

layout = go.Layout(xaxis=dict(ticks='', showticklabels=False,
                              zeroline=False),
                   yaxis=dict(ticks='', showticklabels=False,
                              zeroline=False),
                   showlegend=False, hovermode='closest')

fig = go.Figure(data=[p1, p2], layout=layout)

py.offline.iplot(fig)

However, my output looks like this:

[output screenshot not preserved]

If I follow the website line by line, I get this:

from plotly import tools
import plotly as py
import plotly.graph_objs as go
from sklearn import linear_model

py.offline.init_notebook_mode(connected=True)

# Create linear regression object
regr = linear_model.LinearRegression()

# Train the model using the training sets
regr.fit(X_train, y_train)

def data_to_plotly(x):
    k = []

    for i in range(0, len(x)):
        k.append(x[i][0])

    return k

p1 = go.Scatter(x=data_to_plotly(X_test), 
                y=y_test, 
                mode='markers',
                marker=dict(color='black')
               )

p2 = go.Scatter(x=data_to_plotly(X_test), 
                y=regr.predict(X_test),
                mode='lines',
                line=dict(color='blue', width=3)
                )

layout = go.Layout(xaxis=dict(ticks='', showticklabels=False,
                              zeroline=False),
                   yaxis=dict(ticks='', showticklabels=False,
                              zeroline=False),
                   showlegend=False, hovermode='closest')

fig = go.Figure(data=[p1, p2], layout=layout)

py.offline.iplot(fig)

But it generates the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~\Anaconda3\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   2656             try:
-> 2657                 return self._engine.get_loc(key)
   2658             except KeyError:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 0

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
<ipython-input-84-5895927e91e2> in <module>
     21     return k
     22 
---> 23 p1 = go.Scatter(x=data_to_plotly(X_test), 
     24                 y=y_test,
     25                 mode='markers',

<ipython-input-84-5895927e91e2> in data_to_plotly(x)
     17 
     18     for i in range(0, len(x)):
---> 19         k.append(x[i][0])
     20 
     21     return k

~\Anaconda3\lib\site-packages\pandas\core\frame.py in __getitem__(self, key)
   2925             if self.columns.nlevels > 1:
   2926                 return self._getitem_multilevel(key)
-> 2927             indexer = self.columns.get_loc(key)
   2928             if is_integer(indexer):
   2929                 indexer = [indexer]

~\Anaconda3\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   2657                 return self._engine.get_loc(key)
   2658             except KeyError:
-> 2659                 return self._engine.get_loc(self._maybe_cast_indexer(key))
   2660         indexer = self.get_indexer([key], method=method, tolerance=tolerance)
   2661         if indexer.ndim > 1 or indexer.size > 1:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 0
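The `KeyError: 0` above is raised inside `data_to_plotly` at `x[i][0]`: on a pandas DataFrame, `x[i]` looks up a *column label* `i`, not row `i`, so the very first iteration asks for a column named `0`. The helper presumably assumed a 2-D NumPy array, where `x[i]` is row `i`. A minimal sketch with hypothetical toy data (the values are made up; the column name is the one used in the answer) shows the difference:

```python
import numpy as np
import pandas as pd


def data_to_plotly(x):
    # Same loop as in the question: collect the first value of each row.
    k = []
    for i in range(0, len(x)):
        k.append(x[i][0])
    return k


# Hypothetical toy data: three samples of a single feature.
arr = np.array([[1.0], [2.0], [3.0]])
df = pd.DataFrame(arr, columns=["Explained by: GDP per capita"])

# 2-D NumPy array: arr[i] is row i, so arr[i][0] is that row's value.
print(data_to_plotly(arr))

# DataFrame: df[0] is a column lookup, and there is no column labelled 0.
try:
    data_to_plotly(df)
except KeyError as exc:
    print("KeyError:", exc)  # KeyError: 0

# Positional access on a DataFrame goes through .iloc instead.
print(df.iloc[:, 0].tolist())  # [1.0, 2.0, 3.0]
```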

I'm new to Plotly. How do I fix this?

My X_test looks like this:

[screenshot of X_test not preserved]

Recommended answer

I see that X_train is a DataFrame. Plotly is actually pretty friendly with pandas (there are several examples in Plotly's pandas example gallery), so you don't have to deal with intermediate functions like data_to_plotly; that tutorial sadly looks quite outdated. In this case, pass the DataFrame column directly, so the scatters should look something like this:

p1 = go.Scatter(x=X_test['Explained by: GDP per capita'],
                y=y_test, # Assuming y_test is a numpy array or pandas series
                          # if it is also a dataframe you have to specify the column
                mode='markers',
                marker=dict(color='black')
               )

p2 = go.Scatter(x=X_test['Explained by: GDP per capita'],
                y=regr.predict(X_test),
                mode='lines',
                line=dict(color='blue', width=3)
                )
