如何从TensorFlow预测中获取类标签 [英] How to get class labels from TensorFlow prediction

查看:853
本文介绍了如何从TensorFlow预测中获取类标签的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在TF中有一个分类模型,可以获取下一类的概率列表(preds).现在,我要选择最高的元素(argmax),并显示其类标签.

I have a classification model in TF and can get a list of probabilities for the next class (preds). Now I want to select the highest element (argmax) and display its class label.

这似乎很愚蠢,但是如何获得与预测张量中的位置匹配的类标签?

This may seems silly, but how can I get the class label that matches a position in the predictions tensor?

        feed_dict={g['x']: current_char}
        preds, state = sess.run([g['preds'],g['final_state']], feed_dict)
        prediction = tf.argmax(preds, 1)

preds为我提供了每个班级的预测向量.当然,必须有一种简单的方法来仅输出最可能的类(标签)吗?

preds gives me a vector of predictions for each class. Surely there must be an easy way to just output the most likely class (label)?

有关我的模型的一些信息:

Some info about my model:

x = tf.placeholder(tf.int32, [None, num_steps], name='input_placeholder')
y = tf.placeholder(tf.int32, [None, 1], name='labels_placeholder')
batch_size = batch_size = tf.shape(x)[0]  
x_one_hot = tf.one_hot(x, num_classes)
rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in
              tf.split(x_one_hot, num_steps, 1)] 

tmp = tf.stack(rnn_inputs)
print(tmp.get_shape())
tmp2 = tf.transpose(tmp, perm=[1, 0, 2])
print(tmp2.get_shape())
rnn_inputs = tmp2


with tf.variable_scope('softmax'):
    W = tf.get_variable('W', [state_size, num_classes])
    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))


rnn_outputs = rnn_outputs[:, num_steps - 1, :]
rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
y_reshaped = tf.reshape(y, [-1])
logits = tf.matmul(rnn_outputs, W) + b
predictions = tf.nn.softmax(logits)

推荐答案

您可以为此使用tf.reduce_max().我会推荐您此答案. 让我知道它是否有效-否则将进行编辑.

You can use tf.reduce_max() for this. I would refer you to this answer. Let me know if it works - will edit if it doesn't.

这篇关于如何从TensorFlow预测中获取类标签的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆