如何正确使用 tf.metrics.accuracy? [英] How to properly use tf.metrics.accuracy?
问题描述
我在使用 tf.metrics
中的 accuracy
函数处理以 logits 作为输入的多分类问题时遇到了一些麻烦.
I have some trouble using the accuracy
function from tf.metrics
for a multiple classification problem with logits as input.
我的模型输出如下:
logits = [[0.1, 0.5, 0.4],
[0.8, 0.1, 0.1],
[0.6, 0.3, 0.2]]
我的标签是一个热编码向量:
And my labels are one hot encoded vectors:
labels = [[0, 1, 0],
[1, 0, 0],
[0, 0, 1]]
当我尝试做类似 tf.metrics.accuracy(labels, logits)
的事情时,它永远不会给出正确的结果.我显然做错了什么,但我不知道是什么.
When I try to do something like tf.metrics.accuracy(labels, logits)
it never gives the correct result. I am obviously doing something wrong but I can't figure what it is.
推荐答案
TL;DR
准确度函数 tf.metrics.accuracy 计算预测的频率根据它创建的两个局部变量匹配标签:total
和 count
,它们用于计算 logits
匹配 labels 的频率代码>.
The accuracy function tf.metrics.accuracy calculates how often predictions matches labels based on two local variables it creates: total
and count
, that are used to compute the frequency with which logits
matches labels
.
acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
predictions=tf.argmax(logits,1))
print(sess.run([acc, acc_op]))
print(sess.run([acc]))
# Output
#[0.0, 0.66666669]
#[0.66666669]
- acc(准确度):仅使用
total
和count
返回指标,不更新指标. - acc_op(更新):更新指标.
- acc (accuracy): simply returns the metrics using
total
andcount
, doesnt update the metrics. - acc_op (update up): updates the metrics.
要了解 acc 返回 0.0
的原因,请查看以下详细信息.
To understand why the acc returns 0.0
, go through the details below.
使用简单示例的详细信息:
logits = tf.placeholder(tf.int64, [2,3])
labels = tf.Variable([[0, 1, 0], [1, 0, 1]])
acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
predictions=tf.argmax(logits,1))
初始化变量:
由于metrics.accuracy
创建了两个局部变量total
和count
,我们需要调用local_variables_initializer()
初始化它们.
Since metrics.accuracy
creates two local variables total
and count
, we need to call local_variables_initializer()
to initialize them.
sess = tf.Session()
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
stream_vars = [i for i in tf.local_variables()]
print(stream_vars)
#[<tf.Variable 'accuracy/total:0' shape=() dtype=float32_ref>,
# <tf.Variable 'accuracy/count:0' shape=() dtype=float32_ref>]
了解更新操作和准确度计算:
print('acc:',sess.run(acc, {logits:[[0,1,0],[1,0,1]]}))
#acc: 0.0
print('[total, count]:',sess.run(stream_vars))
#[total, count]: [0.0, 0.0]
以上返回 0.0 的准确性,因为 total
和 count
为零,尽管提供了匹配的输入.
The above returns 0.0 for accuracy as total
and count
are zeros, inspite of giving matching inputs.
print('ops:', sess.run(acc_op, {logits:[[0,1,0],[1,0,1]]}))
#ops: 1.0
print('[total, count]:',sess.run(stream_vars))
#[total, count]: [2.0, 2.0]
使用新输入,在调用更新操作时计算准确度.注意:由于所有的 logits 和标签都匹配,我们得到了 1.0 的准确度,并且局部变量 total
和 count
实际上给出了 total 正确预测
和总比较
.
With the new inputs, the accuracy is calculated when the update op is called. Note: since all the logits and labels match, we get accuracy of 1.0 and the local variables total
and count
actually give total correctly predicted
and the total comparisons made
.
现在我们用新输入(不是更新操作)调用 accuracy
:
Now we call accuracy
with the new inputs (not the update ops):
print('acc:', sess.run(acc,{logits:[[1,0,0],[0,1,0]]}))
#acc: 1.0
准确度调用不会使用新输入更新指标,它只是使用两个局部变量返回值.注意:在这种情况下,logits 和标签不匹配.现在再次调用更新操作:
Accuracy call doesnt update the metrics with the new inputs, it just returns the value using the two local variables. Note: the logits and labels dont match in this case. Now calling update ops again:
print('op:',sess.run(acc_op,{logits:[[0,1,0],[0,1,0]]}))
#op: 0.75
print('[total, count]:',sess.run(stream_vars))
#[total, count]: [3.0, 4.0]
指标更新为新输入
有关如何在训练期间使用指标以及如何在验证期间重置它们的更多信息,可以找到 此处.
For more information on how to use the metrics during training and how to reset them during validation, can be found here.
这篇关于如何正确使用 tf.metrics.accuracy?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!