LogisticRegression:未知标签类型:在 python 中使用 sklearn 的“连续" [英] LogisticRegression: Unknown label type: 'continuous' using sklearn in python

查看:34
本文介绍了LogisticRegression:未知标签类型:在 python 中使用 sklearn 的“连续"的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有以下代码来测试 sklearn python 库的一些最流行的机器学习算法:

I have the following code to test some of most popular ML algorithms of sklearn python library:

import numpy as np
from sklearn                        import metrics, svm
from sklearn.linear_model           import LinearRegression
from sklearn.linear_model           import LogisticRegression
from sklearn.tree                   import DecisionTreeClassifier
from sklearn.neighbors              import KNeighborsClassifier
from sklearn.discriminant_analysis  import LinearDiscriminantAnalysis
from sklearn.naive_bayes            import GaussianNB
from sklearn.svm                    import SVC

trainingData    = np.array([ [2.3, 4.3, 2.5],  [1.3, 5.2, 5.2],  [3.3, 2.9, 0.8],  [3.1, 4.3, 4.0]  ])
trainingScores  = np.array( [3.4, 7.5, 4.5, 1.6] )
predictionData  = np.array([ [2.5, 2.4, 2.7],  [2.7, 3.2, 1.2] ])

clf = LinearRegression()
clf.fit(trainingData, trainingScores)
print("LinearRegression")
print(clf.predict(predictionData))

clf = svm.SVR()
clf.fit(trainingData, trainingScores)
print("SVR")
print(clf.predict(predictionData))

clf = LogisticRegression()
clf.fit(trainingData, trainingScores)
print("LogisticRegression")
print(clf.predict(predictionData))

clf = DecisionTreeClassifier()
clf.fit(trainingData, trainingScores)
print("DecisionTreeClassifier")
print(clf.predict(predictionData))

clf = KNeighborsClassifier()
clf.fit(trainingData, trainingScores)
print("KNeighborsClassifier")
print(clf.predict(predictionData))

clf = LinearDiscriminantAnalysis()
clf.fit(trainingData, trainingScores)
print("LinearDiscriminantAnalysis")
print(clf.predict(predictionData))

clf = GaussianNB()
clf.fit(trainingData, trainingScores)
print("GaussianNB")
print(clf.predict(predictionData))

clf = SVC()
clf.fit(trainingData, trainingScores)
print("SVC")
print(clf.predict(predictionData))

前两个工作正常,但在 LogisticRegression 调用中出现以下错误:

The first two works ok, but I got the following error in LogisticRegression call:

root@ubupc1:/home/ouhma# python stack.py 
LinearRegression
[ 15.72023529   6.46666667]
SVR
[ 3.95570063  4.23426243]
Traceback (most recent call last):
  File "stack.py", line 28, in <module>
    clf.fit(trainingData, trainingScores)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/linear_model/logistic.py", line 1174, in fit
    check_classification_targets(y)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/utils/multiclass.py", line 172, in check_classification_targets
    raise ValueError("Unknown label type: %r" % y_type)
ValueError: Unknown label type: 'continuous'

输入的数据和之前的调用一样,那么这里发生了什么?

The input data is the same as in the previous calls, so what is going on here?

顺便说一下,为什么 LinearRegression()SVR() 算法的第一次预测存在巨大差异 (15.72 vs 3.95)?

And by the way, why there is a huge diference in the first prediction of LinearRegression() and SVR() algorithms (15.72 vs 3.95)?

推荐答案

您正在将浮点数传递给分类器,该分类器将分类值作为目标向量.如果您将其转换为 int ,它将被接受为输入(尽管这样做是否正确会令人怀疑).

You are passing floats to a classifier which expects categorical values as the target vector. If you convert it to int it will be accepted as input (although it will be questionable if that's the right way to do it).

最好使用 scikit 的 labelEncoder 函数.

It would be better to convert your training scores by using scikit's labelEncoder function.

同样适用于您的 DecisionTree 和 KNeighbors 限定符.

The same is true for your DecisionTree and KNeighbors qualifier.

from sklearn import preprocessing
from sklearn import utils

lab_enc = preprocessing.LabelEncoder()
encoded = lab_enc.fit_transform(trainingScores)
>>> array([1, 3, 2, 0], dtype=int64)

print(utils.multiclass.type_of_target(trainingScores))
>>> continuous

print(utils.multiclass.type_of_target(trainingScores.astype('int')))
>>> multiclass

print(utils.multiclass.type_of_target(encoded))
>>> multiclass

这篇关于LogisticRegression:未知标签类型:在 python 中使用 sklearn 的“连续"的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆