Plotting a decision boundary separating 2 classes using Matplotlib's pyplot
Question
I could really use a tip to help me with plotting a decision boundary that separates two classes of data. I created some sample data (from a Gaussian distribution) via Python NumPy. In this case, every data point is a 2D coordinate, i.e., a column vector consisting of 2 rows, e.g.,
$\begin{bmatrix} 1 \\ 2 \end{bmatrix}$
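The distinction between a plain 1-D NumPy array and an explicit 2x1 column vector matters for the `.T.dot(...)` calls used later in the question. A small illustration (the variable names `x_flat` and `x_col` are only for this example):

```python
import numpy as np

# A 1-D array: NumPy does not distinguish rows from columns here.
x_flat = np.array([1, 2])
print(x_flat.shape)        # (2,)
print(x_flat.T.shape)      # (2,)  transposing a 1-D array is a no-op

# The same point as an explicit 2x1 column vector, as described above.
x_col = np.array([1, 2]).reshape(2, 1)
print(x_col.shape)         # (2, 1)

# Inner product x^T x: a 1x1 array for the column vector ...
print(x_col.T.dot(x_col))  # [[5]]
# ... and a plain scalar for the 1-D array.
print(x_flat.dot(x_flat))  # 5
```

So `(x_vec - mu_vec).T.dot(x_vec - mu_vec)` yields a 1x1 array when the inputs are column vectors, which is why the question reshapes the mean vectors with `reshape(1,2).T`.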
Let's assume I have 2 classes, class1 and class2, and I created 100 data points for class1 and 100 data points for class2 via the code below (assigned to the variables x1_samples and x2_samples).
import numpy as np
mu_vec1 = np.array([0,0])
cov_mat1 = np.array([[2,0],[0,2]])
x1_samples = np.random.multivariate_normal(mu_vec1, cov_mat1, 100)
mu_vec1 = mu_vec1.reshape(1,2).T # to 1-col vector
mu_vec2 = np.array([1,2])
cov_mat2 = np.array([[1,0],[0,1]])
x2_samples = np.random.multivariate_normal(mu_vec2, cov_mat2, 100)
mu_vec2 = mu_vec2.reshape(1,2).T
When I plot the data points for each class, it would look like this:
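The plot image itself was not preserved. A minimal sketch that produces such a scatter plot, assuming a seeded `np.random.default_rng` generator (the question uses the legacy `np.random` functions; the seed is only for reproducibility) and the non-interactive Agg backend so it runs without a display:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend; in an interactive session call plt.show() instead
import matplotlib.pyplot as plt

# Same means and covariances as in the question; the seed 0 is an arbitrary assumption.
rng = np.random.default_rng(0)
mu_vec1, cov_mat1 = np.array([0, 0]), np.array([[2, 0], [0, 2]])
mu_vec2, cov_mat2 = np.array([1, 2]), np.array([[1, 0], [0, 1]])
x1_samples = rng.multivariate_normal(mu_vec1, cov_mat1, 100)
x2_samples = rng.multivariate_normal(mu_vec2, cov_mat2, 100)

fig, ax = plt.subplots(figsize=(7, 7))
# Each sample is a row (x1, x2): column 0 goes on the x-axis, column 1 on the y-axis.
ax.scatter(x1_samples[:, 0], x1_samples[:, 1], marker="o", color="green", alpha=0.5, label="class1")
ax.scatter(x2_samples[:, 0], x2_samples[:, 1], marker="^", color="blue", alpha=0.5, label="class2")
ax.set_xlabel("x1")
ax.set_ylabel("x2")
ax.legend(loc="upper right")
```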
Now, I came up with an equation for a decision boundary to separate both classes and would like to add it to the plot. However, I am not really sure how I can plot this function:
def decision_boundary(x_vec, mu_vec1, mu_vec2):
    # x_vec, mu_vec1 and mu_vec2 are 2x1 column vectors
    g1 = (x_vec-mu_vec1).T.dot((x_vec-mu_vec1))
    g2 = 2*( (x_vec-mu_vec2).T.dot((x_vec-mu_vec2)) )
    return g1 - g2
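One common way to draw a boundary given implicitly as f(x1, x2) = 0 is to evaluate f on a grid and draw only the zero level with `contour`. The sketch below assumes a vectorized rewrite of the question's function (the name `decision_boundary_grid` is hypothetical) that works on whole `meshgrid` arrays instead of single column vectors:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt

mu_vec1 = np.array([0.0, 0.0])
mu_vec2 = np.array([1.0, 2.0])

def decision_boundary_grid(X1, X2, mu_vec1, mu_vec2):
    """Hypothetical vectorized version of the question's g1 - g2, element-wise on a grid."""
    g1 = (X1 - mu_vec1[0])**2 + (X2 - mu_vec1[1])**2         # ||x - mu1||^2
    g2 = 2 * ((X1 - mu_vec2[0])**2 + (X2 - mu_vec2[1])**2)   # 2 * ||x - mu2||^2
    return g1 - g2

# Evaluate the function on a regular grid covering the plot area.
x1 = np.linspace(-6, 6, 400)
x2 = np.linspace(-6, 8, 400)
X1, X2 = np.meshgrid(x1, x2)
Z = decision_boundary_grid(X1, X2, mu_vec1, mu_vec2)

fig, ax = plt.subplots(figsize=(7, 7))
# The boundary is where the function equals zero, so draw only that contour level.
ax.contour(X1, X2, Z, levels=[0], colors="red", linestyles="--")
ax.set_xlabel("x1")
ax.set_ylabel("x2")
```

This plots the function exactly as written in the question. If the covariance term that appears as `np.log(16)` in the analytical solution below is included, the boundary becomes g1 - g2 = -ln 16, so you would draw `levels=[-np.log(16)]` instead of `levels=[0]`.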
I would really appreciate any help!
Intuitively (if I did my math right), I would expect the decision boundary to look somewhat like this red line when I plot the function...
Recommended answer
Those were some great suggestions, thanks a lot for your help! I ended up solving the equation analytically, and this is the solution I ended up with (I just want to post it for future reference):
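For reference, here is one way the formula in `decision_boundary` below can be derived, assuming the standard Gaussian discriminant functions with equal priors P(ω1) = P(ω2) (the answer does not state the priors) and dropping the constant -ln 2π that is common to both classes. The question's function corresponds to the third line without the ln 16 term, which comes from the covariance determinants:

```latex
\begin{aligned}
g_i(\mathbf{x}) &= -\tfrac{1}{2}(\mathbf{x}-\boldsymbol{\mu}_i)^T \Sigma_i^{-1} (\mathbf{x}-\boldsymbol{\mu}_i) - \tfrac{1}{2}\ln|\Sigma_i| + \ln P(\omega_i) \\
g_1(\mathbf{x}) &= -\tfrac{1}{4}\lVert\mathbf{x}\rVert^2 - \tfrac{1}{2}\ln 4 + \ln P(\omega_1), \qquad \Sigma_1 = 2I,\ \boldsymbol{\mu}_1 = (0,0)^T \\
g_2(\mathbf{x}) &= -\tfrac{1}{2}\lVert\mathbf{x}-\boldsymbol{\mu}_2\rVert^2 + \ln P(\omega_2), \qquad \Sigma_2 = I,\ \boldsymbol{\mu}_2 = (1,2)^T \\
g_1 = g_2 &\iff \lVert\mathbf{x}\rVert^2 - 2\lVert\mathbf{x}-\boldsymbol{\mu}_2\rVert^2 + \ln 16 = 0 \qquad (P(\omega_1) = P(\omega_2)) \\
&\iff -x_1^2 - x_2^2 + 4x_1 + 8x_2 - 10 + \ln 16 = 0 \\
&\iff (x_1 - 2)^2 + (x_2 - 4)^2 = 10 + \ln 16 \\
&\iff x_2 = 4 \pm \sqrt{-x_1^2 + 4x_1 + 6 + \ln 16}
\end{aligned}
```

The boundary is therefore a circle centred at (2, 4) with radius sqrt(10 + ln 16) ≈ 3.57, and `decision_boundary` returns its lower branch (the minus sign).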
# 2-category classification with random 2D-sample data
# from a multivariate normal distribution
import numpy as np
from matplotlib import pyplot as plt
def decision_boundary(x_1):
    """ Calculates the x_2 value for plotting the decision boundary."""
    # Lower branch only; returns NaN where the square root's argument is negative.
    return 4 - np.sqrt(-x_1**2 + 4*x_1 + 6 + np.log(16))
# Generating a Gaussian dataset:
# creating random vectors from the multivariate normal distribution
# given mean and covariance
mu_vec1 = np.array([0,0])
cov_mat1 = np.array([[2,0],[0,2]])
x1_samples = np.random.multivariate_normal(mu_vec1, cov_mat1, 100)
mu_vec1 = mu_vec1.reshape(1,2).T # to 1-col vector
mu_vec2 = np.array([1,2])
cov_mat2 = np.array([[1,0],[0,1]])
x2_samples = np.random.multivariate_normal(mu_vec2, cov_mat2, 100)
mu_vec2 = mu_vec2.reshape(1,2).T # to 1-col vector
# Main scatter plot and plot annotation
f, ax = plt.subplots(figsize=(7, 7))
ax.scatter(x1_samples[:,0], x1_samples[:,1], marker='o', color='green', s=40, alpha=0.5)
ax.scatter(x2_samples[:,0], x2_samples[:,1], marker='^', color='blue', s=40, alpha=0.5)
plt.legend(['Class1 (w1)', 'Class2 (w2)'], loc='upper right')
plt.title('Densities of 2 classes with 100 bivariate random patterns each')
plt.ylabel('x2')
plt.xlabel('x1')
ftext = 'p(x|w1) ~ N(mu1=(0,0)^t, cov1=2I)\np(x|w2) ~ N(mu2=(1,2)^t, cov2=I)'
plt.figtext(.15,.8, ftext, fontsize=11, ha='left')
# Adding decision boundary to plot
x_1 = np.arange(-5, 5, 0.1)
bound = decision_boundary(x_1)
plt.plot(x_1, bound, 'r--', lw=3)
plt.show()
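Completing the square in `decision_boundary` gives (x1 - 2)^2 + (x2 - 4)^2 = 10 + ln 16, so the function above traces only the lower half of a circle and yields NaN (with a RuntimeWarning) for x_1 outside roughly [-1.57, 5.57]. A sketch that draws the whole circle and numerically checks that both class discriminants agree on it, assuming equal priors and the standard Gaussian discriminant (the helper name `discriminant` is hypothetical):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt

mu1, mu2 = np.array([0.0, 0.0]), np.array([1.0, 2.0])
cov1, cov2 = 2 * np.eye(2), np.eye(2)

def discriminant(x, mu, cov):
    """Gaussian log-discriminant; equal priors assumed, common -ln(2*pi) dropped."""
    diff = x - mu
    return -0.5 * diff @ np.linalg.inv(cov) @ diff - 0.5 * np.log(np.linalg.det(cov))

# Centre and radius of the circular boundary obtained by completing the square.
center = np.array([2.0, 4.0])
radius = np.sqrt(10 + np.log(16))

theta = np.linspace(0, 2 * np.pi, 400)
boundary = center + radius * np.column_stack([np.cos(theta), np.sin(theta)])

# Sanity check: g1(x) - g2(x) should vanish at every point on the circle.
gaps = [discriminant(p, mu1, cov1) - discriminant(p, mu2, cov2) for p in boundary]
print(max(abs(g) for g in gaps))  # tiny, only floating-point rounding remains

fig, ax = plt.subplots(figsize=(7, 7))
ax.plot(boundary[:, 0], boundary[:, 1], "r--", lw=3)
ax.set_aspect("equal")
```

Points inside the circle (including mu2) are assigned to class 2 and points outside to class 1, which matches the larger spread of class 1.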
The code can be found here.
I also have a convenience function for plotting the decision regions of classifiers that implement fit and predict methods (e.g., the classifiers in scikit-learn), which is useful if the solution cannot be found analytically. A more detailed description of how it works can be found here.
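The general idea of such a function can be sketched as: predict a class for every point of a fine grid and colour the grid by the prediction. This is not the author's function; `plot_decision_regions_sketch` and the toy `NearestMeanClassifier` are hypothetical stand-ins so the example needs only NumPy and Matplotlib. Any object with scikit-learn-style `fit`/`predict` methods could be passed instead. Note that the nearest-mean rule gives a straight-line boundary, not the circle derived above.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt

class NearestMeanClassifier:
    """Hypothetical toy classifier with scikit-learn-style fit/predict methods."""
    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.means_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        # Squared distance from every row of X to every class mean; pick the closest.
        d = ((X[:, None, :] - self.means_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[d.argmin(axis=1)]

def plot_decision_regions_sketch(X, y, clf, ax, resolution=0.05):
    """Colour each grid cell by clf.predict, then overlay the samples (up to 4 classes)."""
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                           np.arange(x2_min, x2_max, resolution))
    # Classify every grid point at once, then restore the grid shape.
    Z = clf.predict(np.column_stack([xx1.ravel(), xx2.ravel()])).reshape(xx1.shape)
    ax.contourf(xx1, xx2, Z, alpha=0.3)
    for label, marker in zip(np.unique(y), ["o", "^", "s", "x"]):
        ax.scatter(X[y == label, 0], X[y == label, 1], marker=marker, label=str(label))
    ax.legend(loc="upper right")
    return Z

# Data with the same parameters as in the question (seed 0 is an arbitrary assumption).
rng = np.random.default_rng(0)
X = np.vstack([rng.multivariate_normal([0, 0], 2 * np.eye(2), 100),
               rng.multivariate_normal([1, 2], np.eye(2), 100)])
y = np.array([1] * 100 + [2] * 100)

clf = NearestMeanClassifier().fit(X, y)
fig, ax = plt.subplots(figsize=(7, 7))
Z = plot_decision_regions_sketch(X, y, clf, ax)
```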