使用 one-hot 代码的 Tensorflow 混淆矩阵 [英] Tensorflow confusion matrix using one-hot code

查看:46
本文介绍了使用 one-hot 代码的 Tensorflow 混淆矩阵的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我使用 RNN 进行多类分类,这是我的 RNN 主要代码:

I have multi-class classification using RNN and here is my main code for RNN:

def RNN(x, weights, biases):
    x = tf.unstack(x, input_size, 1)
    lstm_cell = rnn.BasicLSTMCell(num_unit, forget_bias=1.0, state_is_tuple=True) 
    stacked_lstm = rnn.MultiRNNCell([lstm_cell]*lstm_size, state_is_tuple=True) 
    outputs, states = tf.nn.static_rnn(stacked_lstm, x, dtype=tf.float32)

    return tf.matmul(outputs[-1], weights) + biases

logits = RNN(X, weights, biases)
prediction = tf.nn.softmax(logits)

cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(cost)

correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

我必须将所有输入分为 6 个类,每个类由一个热代码标签组成,如下所示:

I have to classify all inputs to 6 classes and each of classes is composed of one-hot code label as the follow:

happy = [1, 0, 0, 0, 0, 0]
angry = [0, 1, 0, 0, 0, 0]
neutral = [0, 0, 1, 0, 0, 0]
excited = [0, 0, 0, 1, 0, 0]
embarrassed = [0, 0, 0, 0, 1, 0]
sad = [0, 0, 0, 0, 0, 1]

问题是我无法使用 tf.confusion_matrix() 函数打印混淆矩阵.

The problem is I cannot print confusion matrix using tf.confusion_matrix() function.

有没有办法使用这些标签打印混淆矩阵?

Is there any way to print confusion matrix using those labels?

如果没有,只有在需要打印混淆矩阵时,如何才能将 one-hot 代码转换为整数索引?

If not, how can I convert one-hot code to integer indices only when I need to print confusion matrix?

推荐答案

您不能使用 one-hot 向量作为 labelspredictions 的输入参数生成混淆矩阵.您必须直接为其提供包含标签的一维张量.

You cannot generate confusion matrix using one-hot vectors as input parameters of labels and predictions. You will have to supply it a 1D tensor containing your labels directly.

要将您的一个热向量转换为普通标签,请使用 argmax功能:

To convert your one hot vector to normal label, make use of argmax function:

label = tf.argmax(one_hot_tensor, axis = 1)

之后你可以像这样打印你的confusion_matrix:

After that you can print your confusion_matrix like this:

import tensorflow as tf

num_classes = 2
prediction_arr = tf.constant([1,  1, 1, 1,  0, 0, 0, 0,  1, 1])
labels_arr     = tf.constant([0,  1, 1, 1,  1, 1, 1, 1,  0, 0])

confusion_matrix = tf.confusion_matrix(labels_arr, prediction_arr, num_classes)
with tf.Session() as sess:
    print(confusion_matrix.eval())

输出:

[[0 3]
 [4 3]]

这篇关于使用 one-hot 代码的 Tensorflow 混淆矩阵的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆