How to get the index of the highest value in a list per row in a Spark DataFrame? [PySpark]


Question

I have done LDA topic modelling and have it stored in lda_model.

After transforming my original input dataset I retrieve a DataFrame. One of the columns is topicDistribution, which holds the probability of the row belonging to each topic from the LDA model. I therefore want to get the index of the maximum value in that list per row.

df -- | 'list_of_words'   | 'index' | 'topicDistribution'   |
      | ['product','...'] |    0    | [0.08,0.2,0.4,0.0001] |
      | .....             |   ...   | ........              |

I want to transform df such that an additional column is added which is the argmax of the topicDistribution list per row.

df_transformed -- | 'list_of_words'   | 'index' | 'topicDistribution'   | 'topicID' |
                  | ['product','...'] |    0    | [0.08,0.2,0.4,0.0001] |     2     |
                  | ......            |  ....   | .....                 |   ....    |

How would I go about doing this?

Recommended answer

You can create a user-defined function (UDF) to get the index of the maximum value:

from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType

# UDF returning the position of the largest element in the list
max_index = f.udf(lambda x: x.index(max(x)), IntegerType())

# Add the most likely topic per row as a new column
df = df.withColumn("topicID", max_index("topicDistribution"))


Example

>>> from pyspark.sql import functions as f
>>> from pyspark.sql.types import IntegerType 
>>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])
>>> df.show()
+-----------------+
|topicDistribution|
+-----------------+
|  [0.2, 0.3, 0.5]|
+-----------------+

>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
>>> df.withColumn("topicID", max_index("topicDistribution")).show()
+-----------------+-------+
|topicDistribution|topicID|
+-----------------+-------+
|  [0.2, 0.3, 0.5]|      2|
+-----------------+-------+


Since you mentioned that the lists in topicDistribution are numpy arrays, you can update the max_index udf as follows:

# Convert the numpy array to a Python list before calling .index()
max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())
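
If topicDistribution is actually a Spark ML vector (which is what LDAModel.transform typically produces) rather than a plain Python list, you can also avoid the UDF entirely. The sketch below is an alternative, not part of the original answer: it assumes Spark 3.0+ for vector_to_array, introduces a temporary topicArr column for illustration, and relies on the built-in array_position / array_max SQL functions (array_position is 1-based, hence the "- 1"):

from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as f

# Convert the ML vector into a plain array column (requires Spark 3.0+)
df = df.withColumn("topicArr", vector_to_array("topicDistribution"))

# array_position is 1-based, so subtract 1 to get a 0-based topic index
df = df.withColumn(
    "topicID",
    f.expr("array_position(topicArr, array_max(topicArr)) - 1").cast("int")
).drop("topicArr")

Built-in array functions like these generally outperform a Python UDF because the computation stays in the JVM instead of round-tripping through Python.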
