How to plot a correlation heatmap when using pyspark + databricks
Question
I am studying pyspark in databricks. I want to generate a correlation heatmap. Let's say this is my data:
myGraph = spark.createDataFrame([(1.3, 2.1, 3.0),
                                 (2.5, 4.6, 3.1),
                                 (6.5, 7.2, 10.0)],
                                ['col1', 'col2', 'col3'])
And this is my code:
import pyspark
from pyspark.sql import SparkSession
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from ggplot import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.mllib.stat import Statistics

myGraph = spark.createDataFrame([(1.3, 2.1, 3.0),
                                 (2.5, 4.6, 3.1),
                                 (6.5, 7.2, 10.0)],
                                ['col1', 'col2', 'col3'])

vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1', 'col2', 'col3'], outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)
matrix.collect()[0]["pearson({})".format(vector_col)].values
Up to this point I can get the correlation matrix. The result looks like:
Now my problems are:
- How do I convert the matrix to a data frame? I have tried the methods from "How to convert DenseMatrix to spark DataFrame in pyspark?" and "How to get correlation matrix values pyspark", but they do not work for me.
- How to generate a correlation heatmap which looks like:
I have only just started studying pyspark and databricks, so either ggplot or matplotlib is fine for my problem.
Solution
I think the point where you are getting confused is:
matrix.collect()[0]["pearson({})".format(vector_col)].values
Calling .values on a DenseMatrix gives you one flat list of all values, but what you are actually looking for is a list of lists representing the correlation matrix.
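The flat-versus-nested difference can be illustrated locally with numpy (a sketch: the array here stands in for the DenseMatrix's data, it is not the Spark object itself):

```python
import numpy as np

# A 3x3 correlation-style matrix, standing in for the DenseMatrix's contents.
m = np.array([[1.0, 0.96, 0.98],
              [0.96, 1.0, 0.88],
              [0.98, 0.88, 1.0]])

flat = m.flatten().tolist()   # like .values: one flat list of 9 numbers
nested = m.tolist()           # like .toArray().tolist(): 3 rows of 3 numbers

print(len(flat))     # 9
print(len(nested))   # 3
print(nested[0])     # first row of the matrix
```

A plotting function such as matshow expects the nested form, which is why the flat list is not useful here.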
import matplotlib.pyplot as plt
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

columns = ['col1', 'col2', 'col3']
myGraph = spark.createDataFrame([(1.3, 2.1, 3.0),
                                 (2.5, 4.6, 3.1),
                                 (6.5, 7.2, 10.0)],
                                columns)

vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1', 'col2', 'col3'], outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)
Up to here it is basically your code. Instead of calling .values, use .toArray().tolist() to get a list of lists representing the correlation matrix:
matrix = Correlation.corr(myGraph_vector, vector_col).collect()[0][0]
corrmatrix = matrix.toArray().tolist()
print(corrmatrix)
Output:
[[1.0, 0.9582184104641529, 0.9780872729407004], [0.9582184104641529, 1.0, 0.8776695567739841], [0.9780872729407004, 0.8776695567739841, 1.0]]
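As a quick local sanity check that needs no Spark session (a sketch, not part of the answer's pipeline), numpy's corrcoef reproduces the same Pearson coefficients from the raw rows:

```python
import numpy as np

# The same three rows used in the Spark example.
data = np.array([[1.3, 2.1, 3.0],
                 [2.5, 4.6, 3.1],
                 [6.5, 7.2, 10.0]])

# rowvar=False treats each column as a variable, matching the Spark layout.
corr = np.corrcoef(data, rowvar=False)
print(corr[0, 1])  # ~0.9582184104641529, matching the Spark output above
```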
The advantage of this approach is that you can turn a list of lists easily into a dataframe:
df = spark.createDataFrame(corrmatrix, columns)
df.show()
Output:
+------------------+------------------+------------------+
|              col1|              col2|              col3|
+------------------+------------------+------------------+
|               1.0|0.9582184104641529|0.9780872729407004|
|0.9582184104641529|               1.0|0.8776695567739841|
|0.9780872729407004|0.8776695567739841|               1.0|
+------------------+------------------+------------------+
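If you only need the matrix on the driver for plotting, a plain pandas frame works just as well (a sketch using only pandas, no Spark session required):

```python
import pandas as pd

columns = ['col1', 'col2', 'col3']
corrmatrix = [[1.0, 0.9582184104641529, 0.9780872729407004],
              [0.9582184104641529, 1.0, 0.8776695567739841],
              [0.9780872729407004, 0.8776695567739841, 1.0]]

# Label both rows and columns so the frame reads like a correlation table.
pdf = pd.DataFrame(corrmatrix, index=columns, columns=columns)
print(pdf)
```

This labelled frame can be handed straight to most plotting libraries.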
To answer your second question: this is just one of many ways to plot a heatmap (seaborn, for example, makes it even easier).
def plot_corr_matrix(correlations, attr, fig_no):
    fig = plt.figure(fig_no)
    ax = fig.add_subplot(111)
    cax = ax.matshow(correlations, vmax=1, vmin=-1)
    ax.set_title("Correlation Matrix for Specified Attributes")
    # Pin the tick positions before labelling them, otherwise the labels
    # can end up shifted relative to the matrix cells.
    ax.set_xticks(range(len(attr)))
    ax.set_xticklabels(attr)
    ax.set_yticks(range(len(attr)))
    ax.set_yticklabels(attr)
    fig.colorbar(cax)
    plt.show()

plot_corr_matrix(corrmatrix, columns, 234)
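The seaborn variant mentioned above is a one-liner (a sketch, assuming seaborn is installed; the Agg backend and output filename are illustrative choices for a headless run):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, suitable for headless jobs
import matplotlib.pyplot as plt
import seaborn as sns

columns = ['col1', 'col2', 'col3']
corrmatrix = [[1.0, 0.9582184104641529, 0.9780872729407004],
              [0.9582184104641529, 1.0, 0.8776695567739841],
              [0.9780872729407004, 0.8776695567739841, 1.0]]

# annot=True writes each coefficient inside its cell.
ax = sns.heatmap(corrmatrix, xticklabels=columns, yticklabels=columns,
                 vmin=-1, vmax=1, annot=True, cmap="coolwarm")
ax.set_title("Correlation Matrix for Specified Attributes")
plt.savefig("corr_heatmap.png")
```

In a Databricks notebook you would call display(plt.gcf()) or simply let the cell render the figure instead of saving to a file.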