Logistic Regression PySpark MLlib issue with multiple labels

Question

I am trying to create a LogisticRegression model (LogisticRegressionWithSGD), but it fails with this error:

org.apache.spark.SparkException: Input validation failed.

If I give it binary labels (0, 1 instead of 0, 1, 2), it succeeds.

Sample input:

parsed_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]),
LabeledPoint(0.0, [5.7,4.4,1.5,0.4]),
LabeledPoint(1.0, [6.7,3.1,4.4,1.4]),
LabeledPoint(0.0, [4.8,3.4,1.6,0.2]),
LabeledPoint(2.0, [4.4,3.2,1.3,0.2])]

Code: model = LogisticRegressionWithSGD.train(parsed_data)

Is the Logistic Regression model in Spark supposed to be for binary classification only?

Answer

Although it is not clear from the documentation (you have to dig into the source code to realize it), LogisticRegressionWithSGD works only with binary labels; for multinomial (multiclass) logistic regression, use LogisticRegressionWithLBFGS and pass the number of classes via numClasses:

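from pyspark import SparkContext
from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

# `sc` already exists in the pyspark shell; in a script, get or create one
sc = SparkContext.getOrCreate()

# Three classes: labels 0.0, 1.0 and 2.0
parsed_data = [LabeledPoint(0.0, [4.6, 3.6, 1.0, 0.2]),
               LabeledPoint(0.0, [5.7, 4.4, 1.5, 0.4]),
               LabeledPoint(1.0, [6.7, 3.1, 4.4, 1.4]),
               LabeledPoint(0.0, [4.8, 3.4, 1.6, 0.2]),
               LabeledPoint(2.0, [4.4, 3.2, 1.3, 0.2])]
data = sc.parallelize(parsed_data)  # train() expects an RDD, not a Python list

# SGD is binary-only; uncommenting this line gives the error:
# org.apache.spark.SparkException: Input validation failed.
# model = LogisticRegressionWithSGD.train(data)

# L-BFGS supports multinomial regression; numClasses must cover labels 0..2
model = LogisticRegressionWithLBFGS.train(data, numClasses=3)  # works OK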
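Why does the SGD call fail while L-BFGS with numClasses=3 succeeds? As far as we can tell from the MLlib source (its DataValidators helpers), each trainer checks the labels before training: LogisticRegressionWithSGD accepts only labels that are exactly 0.0 or 1.0, while LogisticRegressionWithLBFGS with numClasses=k accepts integer labels from 0 to k - 1, and a failed check raises "Input validation failed." The sketch below mimics those two checks in plain Python so it runs without Spark. The function names are hypothetical, not MLlib API, and the labels are taken from the sample input above.

```python
# Hypothetical pure-Python sketch of the label checks behind
# "Input validation failed."; it mimics the idea only, it is not Spark code.


def binary_labels_ok(labels):
    """Binary check (SGD trainer): every label must be exactly 0.0 or 1.0."""
    return all(label == 0.0 or label == 1.0 for label in labels)


def multiclass_labels_ok(labels, num_classes):
    """Multiclass check (L-BFGS with numClasses): integer labels in [0, num_classes)."""
    return all(
        label == int(label) and 0 <= label < num_classes
        for label in labels
    )


def infer_num_classes(labels):
    """numClasses must cover the largest label index, so use max label + 1."""
    return int(max(labels)) + 1


# Labels from the sample input in the question
labels = [0.0, 0.0, 1.0, 0.0, 2.0]

print(binary_labels_ok(labels))                   # False: 2.0 is rejected
num_classes = infer_num_classes(labels)
print(num_classes)                                # 3
print(multiclass_labels_ok(labels, num_classes))  # True
print(multiclass_labels_ok(labels, 2))            # False: 2.0 >= 2
```

The max-label-plus-one rule matters because labels are treated as class indices starting at 0: a data set whose labels are only 0 and 2 would still need numClasses=3.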
