TypeError while manipulating arrays in pyspark
Problem description
I am trying to compute the dot product (sum of element-wise products) between 'user_features' and 'movie_features':
+------+-------+--------------------+--------------------+
|userId|movieId| user_features| movie_features|
+------+-------+--------------------+--------------------+
| 18| 1|[0.0, 0.5, 0.0, 0...|[1, 0, 0, 0, 0, 1...|
| 18| 2|[0.1, 0.0, 0.0, 0...|[1, 0, 0, 0, 0, 0...|
| 18| 3|[0.2, 0.0, 0.3, 0...|[0, 0, 0, 0, 0, 1...|
| 18| 4|[0.0, 0.1, 0.0, 0...|[0, 0, 0, 0, 0, 1...|
+------+-------+--------------------+--------------------+
Data types:
df.printSchema()
root
|-- userId: integer (nullable = true)
|-- movieId: integer (nullable = true)
|-- user_features: array (nullable = false)
| |-- element: double (containsNull = true)
|-- movie_features: array (nullable = false)
| |-- element: float (containsNull = true)
None
I use this:
class Solution:
    """
    Data reading, pre-processing...
    """

    @udf("array<double>")
    def miltiply(self, x, y):
        if x and y:
            return [float(a * b) for a, b in zip(x, y)]

    def get_dot_product(self):
        df = self.user_DF.crossJoin(self.movies_DF)
        output = df.withColumn("zipxy", self.miltiply("user_features", "movie_features")) \
                   .withColumn('sumxy', sum([F.col('zipxy').getItem(i) for i in range(20)]))
And I get the following error:
TypeError: Invalid argument, not a string or column: <__main__.Solution instance at 0x000000000A777EC8> of type <type 'instance'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.
What am I missing? I am doing it with a udf because I am using Spark 1.6 and therefore can't use the aggregate or zip_with functions.
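As a side note on the error itself: the message shows a Solution instance being passed where Spark expects a string or Column. Decorating a method with @udf wraps the function before it is bound, so the call self.miltiply(...) hands self to the UDF as its first "column" argument. The intended per-row logic can be checked in plain Python; the names multiply and dot below are illustrative, not from the original post:

```python
def multiply(x, y):
    # element-wise product of two equal-length sequences,
    # mirroring the body of the UDF in the question
    if x and y:
        return [float(a * b) for a, b in zip(x, y)]

def dot(x, y):
    # dot product = sum of element-wise products
    return sum(multiply(x, y))

print(dot([1, 0, 1], [1, 1, 1]))  # 2.0
```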
Recommended answer
If you can use numpy, then:
import numpy as np

df = spark.createDataFrame(
    [(18, 1, [1, 0, 1], [1, 1, 1])]
).toDF('userId', 'movieId', 'user_features', 'movie_features')

df.rdd \
  .map(lambda x: (x[0], x[1], x[2], x[3], float(np.dot(np.array(x[2]), np.array(x[3]))))) \
  .toDF(df.columns + ['dot']) \
  .show()
+------+-------+-------------+--------------+---+
|userId|movieId|user_features|movie_features|dot|
+------+-------+-------------+--------------+---+
| 18| 1| [1, 0, 1]| [1, 1, 1]|2.0|
+------+-------+-------------+--------------+---+
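Outside Spark, the per-row value computed inside the map above reduces to a single numpy call, which is easy to sanity-check on the sample row:

```python
import numpy as np

user_features = [1, 0, 1]
movie_features = [1, 1, 1]

# the same expression the lambda applies to each row of the RDD
dot = float(np.dot(np.array(user_features), np.array(movie_features)))
print(dot)  # 2.0
```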