How to get the index of the highest value in a list per row in a Spark DataFrame? [PySpark]
Question
I have done LDA topic modelling and have it stored in lda_model.
After transforming my original input dataset I retrieve a DataFrame. One of the columns is topicDistribution, which holds the probability of the row belonging to each topic from the LDA model. I therefore want to get the index of the maximum value in the list per row.
df:
| 'list_of_words'   | 'index' | 'topicDistribution'   |
| ['product','...'] | 0       | [0.08,0.2,0.4,0.0001] |
| ...               | ...     | ...                   |
I want to transform df such that an additional column is added which is the argmax of the topicDistribution list per row.
df_transformed:
| 'list_of_words'   | 'index' | 'topicDistribution'   | 'topicID' |
| ['product','...'] | 0       | [0.08,0.2,0.4,0.0001] | 2         |
| ...               | ...     | ...                   | ...       |
How do I do this?
Answer
You can create a user-defined function (UDF) to get the index of the maximum value:
from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType

# UDF that returns the position of the largest element in each row's list
max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
df = df.withColumn("topicID", max_index("topicDistribution"))
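The lambda inside the UDF is ordinary Python: per row it receives the list and returns the position of its largest element. A quick standalone check of that logic, with no Spark required:

```python
# Same logic the UDF applies to each row's topicDistribution list
row = [0.08, 0.2, 0.4, 0.0001]
topic_id = row.index(max(row))
print(topic_id)  # 2
```

Note that list.index returns the first occurrence, so if two topics tie for the highest probability, the lower topic index wins.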
Example
>>> from pyspark.sql import functions as f
>>> from pyspark.sql.types import IntegerType
>>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])
>>> df.show()
+-----------------+
|topicDistribution|
+-----------------+
| [0.2, 0.3, 0.5]|
+-----------------+
>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
>>> df.withColumn("topicID", max_index("topicDistribution")).show()
+-----------------+-------+
|topicDistribution|topicID|
+-----------------+-------+
| [0.2, 0.3, 0.5]| 2|
+-----------------+-------+
Since you mentioned that the lists in topicDistribution are numpy arrays, you can update the max_index udf as follows:
max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())
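The .tolist() call is what makes this version work: a numpy array has no .index() method, so the array is converted to a plain list first. A standalone sketch of that approach alongside the more idiomatic numpy alternative (assuming numpy is installed):

```python
import numpy as np

x = np.array([0.08, 0.2, 0.4, 0.0001])

# np.ndarray has no .index() method, hence the conversion to a list
idx = x.tolist().index(max(x))

# equivalent and more idiomatic with numpy
idx_np = int(np.argmax(x))

print(idx, idx_np)  # 2 2
```

Either expression could be used inside the UDF body; np.argmax avoids the list conversion.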