从Keras多类模型获取混淆矩阵 [英] Get Confusion Matrix From a Keras Multiclass Model
本文介绍了从Keras多类模型获取混淆矩阵的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我正在用Keras构建多类模型.
I am building a multiclass model with Keras.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, callbacks=[checkpoint], validation_data=(X_test, y_test)) # starts training
这是我的测试数据(文本数据)的样子.
Here is how my test data looks like (it's text data).
X_test
Out[25]:
array([[621, 139, 549, ..., 0, 0, 0],
[621, 139, 543, ..., 0, 0, 0]])
y_test
Out[26]:
array([[0, 0, 1],
[0, 1, 0]])
生成预测后...
predictions = model.predict(X_test)
predictions
Out[27]:
array([[ 0.29071924, 0.2483743 , 0.46090645],
[ 0.29566404, 0.45295066, 0.25138539]], dtype=float32)
我做了以下操作来获得混淆矩阵.
I did the following to get the confusion matrix.
y_pred = (predictions > 0.5)
confusion_matrix(y_test, y_pred)
Traceback (most recent call last):
File "<ipython-input-38-430e012b2078>", line 1, in <module>
confusion_matrix(y_test, y_pred)
File "/Users/abrahammathew/anaconda3/lib/python3.6/site-packages/sklearn/metrics/classification.py", line 252, in confusion_matrix
raise ValueError("%s is not supported" % y_type)
ValueError: multilabel-indicator is not supported
但是,我遇到了以上错误.
However, I am getting the above error.
在Keras中建立多类神经网络时,如何获得混淆矩阵?
How can I get a confusion matrix when doing a multiclass neural network in Keras?
推荐答案
您输入的confusion_matrix
必须是一个整数数组,而不是一种热编码.
Your input to confusion_matrix
must be an array of int not one hot encodings.
matrix = metrics.confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))
这篇关于从Keras多类模型获取混淆矩阵的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文