如何正确使用 tf.metrics.accuracy? [英] How to properly use tf.metrics.accuracy?

查看:46
本文介绍了如何正确使用 tf.metrics.accuracy?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在使用 tf.metrics 中的 accuracy 函数处理以 logits 作为输入的多分类问题时遇到了一些麻烦.

I have some trouble using the accuracy function from tf.metrics for a multiple classification problem with logits as input.

我的模型输出如下:

logits = [[0.1, 0.5, 0.4],
          [0.8, 0.1, 0.1],
          [0.6, 0.3, 0.2]]

我的标签是一个热编码向量:

And my labels are one hot encoded vectors:

labels = [[0, 1, 0],
          [1, 0, 0],
          [0, 0, 1]]

当我尝试做类似 tf.metrics.accuracy(labels, logits) 的事情时,它永远不会给出正确的结果.我显然做错了什么,但我不知道是什么.

When I try to do something like tf.metrics.accuracy(labels, logits) it never gives the correct result. I am obviously doing something wrong but I can't figure what it is.

推荐答案

TL;DR

准确度函数 tf.metrics.accuracy 计算预测的频率根据它创建的两个局部变量匹配标签:totalcount,它们用于计算 logits 匹配 labels 的频率.

The accuracy function tf.metrics.accuracy calculates how often predictions matches labels based on two local variables it creates: total and count, that are used to compute the frequency with which logits matches labels.

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1), 
                                  predictions=tf.argmax(logits,1))

print(sess.run([acc, acc_op]))
print(sess.run([acc]))
# Output
#[0.0, 0.66666669]
#[0.66666669]

  • acc(准确度):仅使用 totalcount 返回指标,不更新指标.
  • acc_op(更新):更新指标.
    • acc (accuracy): simply returns the metrics using total and count, doesnt update the metrics.
    • acc_op (update up): updates the metrics.
    • 要了解 acc 返回 0.0 的原因,请查看以下详细信息.

      To understand why the acc returns 0.0, go through the details below.

      使用简单示例的详细信息:

      logits = tf.placeholder(tf.int64, [2,3])
      labels = tf.Variable([[0, 1, 0], [1, 0, 1]])
      
      acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),   
                                        predictions=tf.argmax(logits,1))
      

      初始化变量:

      由于metrics.accuracy创建了两个局部变量totalcount,我们需要调用local_variables_initializer()初始化它们.

      Since metrics.accuracy creates two local variables total and count, we need to call local_variables_initializer() to initialize them.

      sess = tf.Session()
      
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())
      
      stream_vars = [i for i in tf.local_variables()]
      print(stream_vars)
      
      #[<tf.Variable 'accuracy/total:0' shape=() dtype=float32_ref>,
      # <tf.Variable 'accuracy/count:0' shape=() dtype=float32_ref>]
      

      了解更新操作和准确度计算:

      print('acc:',sess.run(acc, {logits:[[0,1,0],[1,0,1]]}))
      #acc: 0.0
      
      print('[total, count]:',sess.run(stream_vars)) 
      #[total, count]: [0.0, 0.0]
      

      以上返回 0.0 的准确性,因为 totalcount 为零,尽管提供了匹配的输入.

      The above returns 0.0 for accuracy as total and count are zeros, inspite of giving matching inputs.

      print('ops:', sess.run(acc_op, {logits:[[0,1,0],[1,0,1]]})) 
      #ops: 1.0
      
      print('[total, count]:',sess.run(stream_vars)) 
      #[total, count]: [2.0, 2.0]
      

      使用新输入,在调用更新操作时计算准确度.注意:由于所有的 logits 和标签都匹配,我们得到了 1.0 的准确度,并且局部变量 totalcount 实际上给出了 total 正确预测总比较.

      With the new inputs, the accuracy is calculated when the update op is called. Note: since all the logits and labels match, we get accuracy of 1.0 and the local variables total and count actually give total correctly predicted and the total comparisons made.

      现在我们用新输入(不是更新操作)调用 accuracy:

      Now we call accuracy with the new inputs (not the update ops):

      print('acc:', sess.run(acc,{logits:[[1,0,0],[0,1,0]]}))
      #acc: 1.0
      

      准确度调用不会使用新输入更新指标,它只是使用两个局部变量返回值.注意:在这种情况下,logits 和标签不匹配.现在再次调用更新操作:

      Accuracy call doesnt update the metrics with the new inputs, it just returns the value using the two local variables. Note: the logits and labels dont match in this case. Now calling update ops again:

      print('op:',sess.run(acc_op,{logits:[[0,1,0],[0,1,0]]}))
      #op: 0.75 
      print('[total, count]:',sess.run(stream_vars)) 
      #[total, count]: [3.0, 4.0]
      

      指标更新为新输入

      有关如何在训练期间使用指标以及如何在验证期间重置它们的更多信息,可以找到 此处.

      For more information on how to use the metrics during training and how to reset them during validation, can be found here.

      这篇关于如何正确使用 tf.metrics.accuracy?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆