How to find the index of the maximum value in a vector column?

Question

I have a Spark DataFrame with the following structure:

root
|-- topicDistribution: vector (nullable = true)

+--------------------+
|   topicDistribution|
+--------------------+
|     [0.1, 0.2]     |
|     [0.3, 0.2]     |
|     [0.5, 0.2]     |
|     [0.1, 0.7]     |
|     [0.1, 0.8]     |
|     [0.1, 0.9]     |
+--------------------+

My question is: how can I add a column containing the index of the maximum value in each row?

The result should look like this:

root
|-- topicDistribution: vector (nullable = true)
|-- max_index: integer (nullable = true)

+--------------------+-----------+
|   topicDistribution| max_index |
+--------------------+-----------+
|     [0.1, 0.2]     |   1       | 
|     [0.3, 0.2]     |   0       | 
|     [0.5, 0.2]     |   0       | 
|     [0.1, 0.7]     |   1       | 
|     [0.1, 0.8]     |   1       | 
|     [0.1, 0.9]     |   1       | 
+--------------------+-----------+

Many thanks!

I tried the following, but got an error:

import org.apache.spark.sql.functions.udf

val func = udf( (x: Vector[Double]) => x.indices.maxBy(x) )

df.withColumn("max_idx",func(($"topicDistribution"))).show()

The error says:

Exception in thread "main" org.apache.spark.sql.AnalysisException: 
cannot resolve 'UDF(topicDistribution)' due to data type mismatch: 
argument 1 requires array<double> type, however, '`topicDistribution`' 
is of vector type.;;
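The lambda in the attempt is not the problem; its parameter type is. Written as Vector[Double] with no import, Vector refers to Scala's collection type scala.collection.immutable.Vector. Spark treats that type as array<double>, which is why the message says "requires array<double> type", while the column actually holds an ML vector. The sketch below checks the same lambda on plain Scala vectors. It assumes only the Scala standard library (no Spark). The names indexOfMax, rows and maxIndices are illustrative, and the data is copied from the question.

```scala
// The question's lambda, applied to plain Scala Vector[Double] values (no Spark involved).
// A Scala Vector is also a function Int => Double, so maxBy(x) ranks each index by its element.
val indexOfMax: Vector[Double] => Int = x => x.indices.maxBy(x)

// Sample rows copied from the question's topicDistribution column
val rows = Seq(
  Vector(0.1, 0.2), Vector(0.3, 0.2), Vector(0.5, 0.2),
  Vector(0.1, 0.7), Vector(0.1, 0.8), Vector(0.1, 0.9)
)

val maxIndices = rows.map(indexOfMax)
println(maxIndices) // List(1, 0, 0, 1, 1, 1)
```

The result matches the expected max_index column. The fix is therefore to declare the UDF over the linalg Vector class stored in the column, which is what the answer does.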

Recommended answer

// create some sample data:
import org.apache.spark.mllib.linalg.{Vectors,Vector}
case class myrow(topics:Vector)

val rdd = sc.parallelize(Array(myrow(Vectors.dense(0.1, 0.2)), myrow(Vectors.dense(0.6, 0.2))))
val mydf = sqlContext.createDataFrame(rdd)
mydf.show()
+----------+
|    topics|
+----------+
|[0.1, 0.2]|
|[0.6, 0.2]|
+----------+

// build the udf
import org.apache.spark.sql.functions.udf
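// Materialize the (possibly sparse) vector as a dense array once, then return the index of its largest value
val func = udf((x: Vector) => {
  val values = x.toArray
  values.indices.maxBy(i => values(i))
})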

mydf.withColumn("max_idx",func($"topics")).show()
+----------+-------+
|    topics|max_idx|
+----------+-------+
|[0.1, 0.2]|      1|
|[0.6, 0.2]|      0|
+----------+-------+
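Both linalg Vector classes also provide a built-in argmax method. It returns the index of the first maximal element, or -1 for an empty vector, and it spares the UDF the array conversion and maxBy. The sketch below is in the same spark-shell style. It assumes Spark 2.0 or later and a column of org.apache.spark.ml.linalg vectors, and the names maxIndex and result are illustrative:

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

// Returns the running session in spark-shell; starts a local one elsewhere.
val spark = SparkSession.builder().master("local[*]").appName("argmax-sketch").getOrCreate()
import spark.implicits._

// The question's sample rows as spark.ml vectors
val df = Seq(
  Vectors.dense(0.1, 0.2), Vectors.dense(0.3, 0.2), Vectors.dense(0.5, 0.2),
  Vectors.dense(0.1, 0.7), Vectors.dense(0.1, 0.8), Vectors.dense(0.1, 0.9)
).map(Tuple1.apply).toDF("topicDistribution")

// Vector.argmax: index of the first maximal element (-1 for an empty vector)
val maxIndex = udf((v: Vector) => v.argmax)

val result = df.withColumn("max_index", maxIndex(col("topicDistribution")))
result.show()
```

If the column holds org.apache.spark.mllib.linalg vectors instead, importing Vector from that package is enough, since the mllib class has the same argmax method.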

// Note: the UDF takes a Vector rather than a Seq because the column holds vectors, as in the original question. The parameter type must match the column's vector class. This example uses org.apache.spark.mllib.linalg.Vector, so if your column holds org.apache.spark.ml.linalg.Vector (the type spark.ml produces in Spark 2.0 and later), import that Vector instead.
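On Spark 3.0 and later you do not have to write your own UDF or choose between the two Vector classes. The function org.apache.spark.ml.functions.vector_to_array accepts either ml or mllib vectors and returns an array<double> column. The built-in array_max and array_position functions can then locate the maximum. The sketch below assumes Spark 3.0+ and reuses the question's sample values. The names arr and result are illustrative.

```scala
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_max, array_position, col}

// Returns the running session in spark-shell; starts a local one elsewhere.
val spark = SparkSession.builder().master("local[*]").appName("vector-to-array-sketch").getOrCreate()
import spark.implicits._

val df = Seq(
  Vectors.dense(0.1, 0.2), Vectors.dense(0.3, 0.2), Vectors.dense(0.5, 0.2),
  Vectors.dense(0.1, 0.7), Vectors.dense(0.1, 0.8), Vectors.dense(0.1, 0.9)
).map(Tuple1.apply).toDF("topicDistribution")

// Convert the vector column to array<double> so that built-in array functions apply
val arr = vector_to_array(col("topicDistribution"))

// array_position is 1-based (0 when absent), so subtract 1 to get a 0-based index
val result = df.withColumn("max_index", (array_position(arr, array_max(arr)) - 1).cast("int"))
result.show()
```

Because array_position returns the first matching position, ties resolve to the lowest index, the same as argmax.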
