3D plot of the error function in a linear regression

Problem Description

I would like to visually plot a 3D graph of the error function calculated for a given slope and y-intercept for a linear regression. This graph will be used to illustrate a gradient descent application.

Let’s suppose we want to model a set of points with a line. To do this we’ll use the standard y=mx+b line equation where m is the line’s slope and b is the line’s y-intercept. To find the best line for our data, we need to find the best set of slope m and y-intercept b values.

A standard approach to solving this type of problem is to define an error function (also called a cost function) that measures how "good" a given line is. This function takes an (m,b) pair and returns an error value based on how well the line fits the data. To compute this error for a given line, we iterate through each (x,y) point in the data set, sum the squared distances between each point’s y value and the candidate line’s y value (computed as mx+b), and divide by the number of points. It’s conventional to square this distance to ensure that it is positive and to make our error function differentiable. In Python, computing the error for a given line looks like this:

# y = mx + b
# m is the slope, b is the y-intercept.
# Each point is expected to expose .x and .y attributes.
# Returns the mean squared error of the line over all points.
def computeErrorForLineGivenPoints(b, m, points):
    totalError = 0
    for point in points:
        totalError += (point.y - (m * point.x + b)) ** 2
    return totalError / float(len(points))
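Written as an equation, the function above computes the mean squared error E(m, b) over the N points (xᵢ, yᵢ). Its partial derivatives, which are the quantities a gradient-descent step uses, follow directly by differentiating each squared term:

```latex
E(m, b) = \frac{1}{N} \sum_{i=1}^{N} \bigl(y_i - (m x_i + b)\bigr)^2

\frac{\partial E}{\partial m} = -\frac{2}{N} \sum_{i=1}^{N} x_i \bigl(y_i - (m x_i + b)\bigr),
\qquad
\frac{\partial E}{\partial b} = -\frac{2}{N} \sum_{i=1}^{N} \bigl(y_i - (m x_i + b)\bigr)
```

Because E is quadratic in (m, b), its graph is a convex, bowl-shaped paraboloid with a single minimum as long as the xᵢ are not all equal, which is what makes it a convenient surface for illustrating gradient descent.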

Since the error function depends on two parameters (m and b), we can visualize it as a two-dimensional surface in 3D space, with the error as the height above each (m, b) pair.
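Concretely, the surface is the set of points (m, b, E(m, b)) over a grid of candidate slopes and intercepts. As a hedged sketch, NumPy broadcasting can evaluate the mean squared error at every grid point without a Python loop. The helper name `error_surface` and the use of plain arrays `xs`, `ys` (rather than objects with `.x`/`.y` attributes, as in the question) are assumptions made for illustration:

```python
import numpy as np

def error_surface(xs, ys, ms, bs):
    """Mean squared error of y = m*x + b for every (m, b) pair on a grid.

    xs, ys: 1-D arrays holding the N data points.
    ms, bs: 1-D arrays of candidate slopes and intercepts.
    Returns (M, B, Z) where Z[i, j] is the error at (M[i, j], B[i, j]).
    """
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    M, B = np.meshgrid(ms, bs)  # both have shape (len(bs), len(ms))
    # Add a trailing axis so each grid point broadcasts against all N points:
    # (len(bs), len(ms), 1) combined with (N,) gives (len(bs), len(ms), N).
    residuals = ys - (M[..., np.newaxis] * xs + B[..., np.newaxis])
    # Average the squared residuals over the data axis.
    Z = np.mean(residuals ** 2, axis=-1)
    return M, B, Z
```

The intermediate residual array has shape (len(bs), len(ms), N), so very fine grids or large data sets may call for looping over one grid axis instead. The returned `M`, `B`, `Z` have the layout `ax.plot_surface` expects.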

Now my question is: how can we plot such a 3D graph using Python?

Here is some skeleton code for building a 3D plot. This snippet is unrelated to the question’s context, but it shows the basics of building a 3D plot. For my example, I would need the x-axis to be the slope, the y-axis the y-intercept, and the z-axis the error.

Can someone help me build such a graph?

import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
import matplotlib.pyplot as plt

def fun(x, y):
    return x**2 + y

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Grid of (x, y) pairs covering [-3, 3) on both axes
x = y = np.arange(-3.0, 3.0, 0.05)
X, Y = np.meshgrid(x, y)

# Evaluate fun at every grid point, then reshape the flat results back to the grid
zs = np.array([fun(x,y) for x,y in zip(np.ravel(X), np.ravel(Y))])
Z = zs.reshape(X.shape)

ax.plot_surface(X, Y, Z)

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')

plt.show()

The above code produces a surface plot that is very similar to what I am looking for.
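A side note on the skeleton: `fun` uses only elementwise NumPy operations, so it can be applied to the 2-D meshgrid arrays directly, with no ravel / list comprehension / reshape round trip. A minimal sketch, reusing the question’s `fun` and grid (the check against the loop version is included only to show the two agree):

```python
import numpy as np

def fun(x, y):
    return x**2 + y

x = y = np.arange(-3.0, 3.0, 0.05)
X, Y = np.meshgrid(x, y)

# Elementwise operations broadcast over the whole grid at once.
Z = fun(X, Y)

# Same values as the element-by-element version from the question.
zs = np.array([fun(a, b) for a, b in zip(np.ravel(X), np.ravel(Y))])
assert np.allclose(Z, zs.reshape(X.shape))
```

This shortcut only works when the function is written with array operations; a function that loops over Python objects (like the error function here) still needs the per-point evaluation or a vectorized rewrite.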

Recommended Answer

Simply replace fun with computeErrorForLineGivenPoints:

import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
import matplotlib.pyplot as plt
import collections

def error(m, b, points):
    """Mean squared error of the line y = m*x + b over the given points."""
    totalError = 0
    for point in points:
        totalError += (point.y - (m * point.x + b)) ** 2
    return totalError / float(len(points))

x = np.arange(-3.0, 3.0, 0.05)
Point = collections.namedtuple('Point', ['x', 'y'])

# Noisy samples around y = 3x + 2 (noise drawn uniformly from [0, 1))
m, b = 3, 2
noise = np.random.random(x.size)
points = [Point(xp, m*xp+b+err) for xp,err in zip(x, noise)]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Candidate slopes and intercepts to evaluate
ms = np.linspace(2.0, 4.0, 10)
bs = np.linspace(1.5, 2.5, 10)

# Evaluate the error at every (m, b) grid point, then reshape back to the grid
M, B = np.meshgrid(ms, bs)
zs = np.array([error(mp, bp, points)
               for mp, bp in zip(np.ravel(M), np.ravel(B))])
Z = zs.reshape(M.shape)

ax.plot_surface(M, B, Z, rstride=1, cstride=1, color='b', alpha=0.5)

ax.set_xlabel('m')
ax.set_ylabel('b')
ax.set_zlabel('error')

plt.show()

This produces a translucent blue surface of the error over the (m, b) grid.
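One thing worth checking: the answer’s noise is drawn uniformly from [0, 1), whose mean is 0.5, so the lowest point of the surface sits near b ≈ 2.5 rather than b = 2, right at the upper edge of the `bs` range. A quick way to confirm the plotting window contains the minimum is to compute the least-squares line directly with `np.polyfit`, which minimizes the same squared error. This is a hedged sketch: `best_fit_line` is a hypothetical helper name, and the seeded `np.random.default_rng(0)` generator is an assumption added for reproducibility.

```python
import numpy as np

def best_fit_line(xs, ys):
    """Return (m, b) minimizing the squared error of y = m*x + b.

    np.polyfit with deg=1 solves the least-squares problem and returns the
    coefficients highest power first, i.e. [slope, intercept].
    """
    m, b = np.polyfit(xs, ys, 1)
    return m, b

# Data generated like the answer's (slope 3, intercept 2, noise in [0, 1)),
# but with a seeded generator so the result is reproducible.
rng = np.random.default_rng(0)
xs = np.arange(-3.0, 3.0, 0.05)
ys = 3 * xs + 2 + rng.random(xs.size)

m_best, b_best = best_fit_line(xs, ys)
print(f"minimum near m={m_best:.3f}, b={b_best:.3f}")
# If the minimum falls outside the plotted window, widen it, e.g.
# bs = np.linspace(1.5, 3.0, 10)
```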

Tip: I renamed computeErrorForLineGivenPoints as error. Generally, there is no need to name a function compute... since almost all functions compute something. You also do not need to specify "GivenPoints" since the function signature shows that points is an argument. If you have other error functions or variables in your program, line_error or total_error might be a better name for this function.
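Since the plot is meant to illustrate gradient descent, one option is to overlay a descent path on the surface. The sketch below runs plain batch gradient descent on the same mean squared error, using its partial derivatives with respect to m and b, and records (m, b, error) at every step. The function name `gradient_descent_path`, the default learning rate, and the step count are illustrative assumptions, not part of the original answer; too large a learning rate will make the iterates diverge.

```python
import numpy as np

def gradient_descent_path(xs, ys, m0, b0, learning_rate=0.05, steps=200):
    """Run batch gradient descent on the MSE of y = m*x + b.

    Returns an array of shape (steps + 1, 3) whose rows are (m, b, error),
    starting at (m0, b0) and ending with the final parameters.
    """
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    n = xs.size
    m, b = float(m0), float(b0)
    path = []
    for _ in range(steps):
        residuals = ys - (m * xs + b)
        path.append((m, b, np.mean(residuals ** 2)))
        # Partial derivatives of the mean squared error.
        grad_m = -2.0 / n * np.sum(xs * residuals)
        grad_b = -2.0 / n * np.sum(residuals)
        # Step against the gradient.
        m -= learning_rate * grad_m
        b -= learning_rate * grad_b
    # Record the final parameters as well.
    residuals = ys - (m * xs + b)
    path.append((m, b, np.mean(residuals ** 2)))
    return np.array(path)
```

With the `points` and `ax` from the answer, something like `xs = np.array([p.x for p in points])`, `ys = np.array([p.y for p in points])`, `path = gradient_descent_path(xs, ys, m0=2.0, b0=1.5)` followed by `ax.plot(path[:, 0], path[:, 1], path[:, 2], 'r.-')` should draw the path on top of the surface, starting from a corner of the plotted grid.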
