使用 Matplotlib 的 pyplot 绘制分离 2 个类的决策边界 [英] Plotting a decision boundary separating 2 classes using Matplotlib's pyplot

查看:45
本文介绍了使用 Matplotlib 的 pyplot 绘制分离 2 个类的决策边界的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我真的可以使用一个技巧来帮助我绘制决策边界以将数据分类.我通过 Python NumPy 创建了一些示例数据(来自高斯分布).在这种情况下,每个数据点都是一个二维坐标,即由 2 行组成的 1 列向量.例如,

I could really use a tip to help me plotting a decision boundary to separate to classes of data. I created some sample data (from a Gaussian distribution) via Python NumPy. In this case, every data point is a 2D coordinate, i.e., a 1 column vector consisting of 2 rows. E.g.,

[ 1
  2 ]

假设我有 2 个类,class1 和 class2,我通过下面的代码为 class1 创建了 100 个数据点,为 class2 创建了 100 个数据点(分配给变量 x1_samples 和 x2_samples).

Let's assume I have 2 classes, class1 and class2, and I created 100 data points for class1 and 100 data points for class2 via the code below (assigned to the variables x1_samples and x2_samples).

mu_vec1 = np.array([0,0])
cov_mat1 = np.array([[2,0],[0,2]])
x1_samples = np.random.multivariate_normal(mu_vec1, cov_mat1, 100)
mu_vec1 = mu_vec1.reshape(1,2).T # to 1-col vector

mu_vec2 = np.array([1,2])
cov_mat2 = np.array([[1,0],[0,1]])
x2_samples = np.random.multivariate_normal(mu_vec2, cov_mat2, 100)
mu_vec2 = mu_vec2.reshape(1,2).T

当我为每个类绘制数据点时,它看起来像这样:

When I plot the data points for each class, it would look like this:

现在,我想出了一个决策边界方程来分隔两个类,并想将其添加到图中.但是,我不确定如何绘制此函数:

Now, I came up with an equation for an decision boundary to separate both classes and would like to add it to the plot. However, I am not really sure how I can plot this function:

def decision_boundary(x_vec, mu_vec1, mu_vec2):
    g1 = (x_vec-mu_vec1).T.dot((x_vec-mu_vec1))
    g2 = 2*( (x_vec-mu_vec2).T.dot((x_vec-mu_vec2)) )
    return g1 - g2

我真的很感激任何帮助!

I would really appreciate any help!

直观地(如果我的数学计算正确)我希望决策边界在我绘制函数时看起来有点像这条红线......

Intuitively (If I did my math right) I would expect the decision boundary to look somewhat like this red line when I plot the function...

推荐答案

这些都是很好的建议,非常感谢您的帮助!我最终通过分析解决了方程,这是我最终得到的解决方案(我只是想发布它以供将来参考:

Those were some great suggestions, thanks a lot for your help! I ended up solving the equation analytically and this is the solution I ended up with (I just want to post it for future reference:

# 2-category classification with random 2D-sample data 
# from a multivariate normal distribution 

import numpy as np
from matplotlib import pyplot as plt

def decision_boundary(x_1):
    """ Calculates the x_2 value for plotting the decision boundary."""
    return 4 - np.sqrt(-x_1**2 + 4*x_1 + 6 + np.log(16))

# Generating a Gaussion dataset:
# creating random vectors from the multivariate normal distribution 
# given mean and covariance 
mu_vec1 = np.array([0,0])
cov_mat1 = np.array([[2,0],[0,2]])
x1_samples = np.random.multivariate_normal(mu_vec1, cov_mat1, 100)
mu_vec1 = mu_vec1.reshape(1,2).T # to 1-col vector

mu_vec2 = np.array([1,2])
cov_mat2 = np.array([[1,0],[0,1]])
x2_samples = np.random.multivariate_normal(mu_vec2, cov_mat2, 100)
mu_vec2 = mu_vec2.reshape(1,2).T # to 1-col vector

# Main scatter plot and plot annotation
f, ax = plt.subplots(figsize=(7, 7))
ax.scatter(x1_samples[:,0], x1_samples[:,1], marker='o', color='green', s=40, alpha=0.5)
ax.scatter(x2_samples[:,0], x2_samples[:,1], marker='^', color='blue', s=40, alpha=0.5)
plt.legend(['Class1 (w1)', 'Class2 (w2)'], loc='upper right') 
plt.title('Densities of 2 classes with 25 bivariate random patterns each')
plt.ylabel('x2')
plt.xlabel('x1')
ftext = 'p(x|w1) ~ N(mu1=(0,0)^t, cov1=I)
p(x|w2) ~ N(mu2=(1,1)^t, cov2=I)'
plt.figtext(.15,.8, ftext, fontsize=11, ha='left')

# Adding decision boundary to plot
x_1 = np.arange(-5, 5, 0.1)
bound = decision_boundary(x_1)
plt.plot(x_1, bound, 'r--', lw=3)

x_vec = np.linspace(*ax.get_xlim())
x_1 = np.arange(0, 100, 0.05)

plt.show()

代码可以在这里

我还有一个方便的函数,用于为实现 fitpredict 方法的分类器绘制决策区域,例如,scikit-learn 中的分类器,如果无法通过分析找到解决方案.可以在此处找到更详细的说明.

I also have a convenience function for plotting decision regions for classifiers that implement a fit and predict method, e.g., the classifiers in scikit-learn, which is useful if the solution cannot be found analytically. A more detailed description how it works can be found here.

这篇关于使用 Matplotlib 的 pyplot 绘制分离 2 个类的决策边界的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆