How to implement false positive rate as a TF metric
Question
I'm trying to add some metrics to a BERT-style model, but I'm struggling with `tf.metrics`. For most metrics it's straightforward, since you can use `tf.metrics.mean`, but that doesn't work for a metric like false positive rate. I know there are `tf.metrics.false_positives` and `tf.metrics.true_negatives`, but since each `tf.metrics` function also returns an associated update op, you can't just do `fpr = fp / (fp + tn)`. How does one go about this?
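Conceptually, each `tf.metrics.*` function keeps running counts in local variables and returns a `(value, update_op)` pair, so a rate such as FPR = FP / (FP + TN) has to be computed from the accumulated counts. The plain-Python sketch below mimics that pattern; the `StreamingFalsePositiveRate` class and the sample batches are hypothetical and for illustration only (this is not TensorFlow code). It also shows why averaging per-batch rates, which is what wrapping a per-batch FPR in `tf.metrics.mean` would do, differs from the pooled rate.

```python
from typing import Iterable


class StreamingFalsePositiveRate:
    """Hypothetical plain-Python analogy of a tf.metrics (value, update_op) pair.

    update() accumulates false-positive and true-negative counts;
    value() reads the rate from the accumulated totals.
    """

    def __init__(self) -> None:
        self.false_positives = 0
        self.true_negatives = 0

    def value(self) -> float:
        # Return 0 instead of dividing by zero when no negatives were seen.
        denom = self.false_positives + self.true_negatives
        return self.false_positives / denom if denom > 0 else 0.0

    def update(self, labels: Iterable[int], predictions: Iterable[int]) -> float:
        for label, pred in zip(labels, predictions):
            if not label:  # only actual negatives contribute to FPR
                if pred:
                    self.false_positives += 1
                else:
                    self.true_negatives += 1
        # Like an update_op, return the rate after folding in this batch.
        return self.value()


# Hypothetical evaluation batches: (labels, thresholded predictions).
batches = [
    ([0, 0, 0, 0, 1], [1, 0, 0, 0, 1]),  # 4 negatives, 1 false positive
    ([0, 1, 1, 1, 1], [1, 1, 1, 1, 1]),  # 1 negative, 1 false positive
]

metric = StreamingFalsePositiveRate()
per_batch = []
for labels, preds in batches:
    per_batch.append(StreamingFalsePositiveRate().update(labels, preds))
    metric.update(labels, preds)

pooled = metric.value()
mean_of_batches = sum(per_batch) / len(per_batch)
print(pooled, mean_of_batches)
```

Here the pooled rate is 2 / 5 = 0.4, while the mean of the two per-batch rates (0.25 and 1.0) is 0.625.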
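Answer
One way is to follow the structure of `tf.metrics.recall`: build the rate from the existing `false_positives` and `true_negatives` metrics, combining their value tensors for the result and their update ops for the update op. This works in TF 1.x graph mode only: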
```python
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.metrics_impl import _remove_squeezable_dimensions
from tensorflow.python.ops.metrics_impl import false_positives
from tensorflow.python.ops.metrics_impl import true_negatives

# Private helper: older TF 1.x releases call it "_aggregate_across_towers",
# later ones "_aggregate_across_replicas", so try both names.
try:
    from tensorflow.python.ops.metrics_impl import _aggregate_across_replicas
except ImportError:
    from tensorflow.python.ops.metrics_impl import (
        _aggregate_across_towers as _aggregate_across_replicas)


def false_positive_rate(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
    """Streaming FPR = FP / (FP + TN), modelled on tf.metrics.recall."""
    if context.executing_eagerly():
        raise RuntimeError('false_positive_rate is not supported when eager '
                           'execution is enabled.')

    with variable_scope.variable_scope(name, 'false_alarm',
                                       (predictions, labels, weights)):
        predictions, labels, weights = _remove_squeezable_dimensions(
            predictions=math_ops.cast(predictions, dtype=dtypes.bool),
            labels=math_ops.cast(labels, dtype=dtypes.bool),
            weights=weights)

        # Each call creates its own local count variable and returns
        # (current count, op that adds this batch to the count).
        false_p, false_positives_update_op = false_positives(
            labels, predictions, weights,
            metrics_collections=None, updates_collections=None, name=None)
        true_n, true_negatives_update_op = true_negatives(
            labels, predictions, weights,
            metrics_collections=None, updates_collections=None, name=None)

        def compute_false_positive_rate(true_n, false_p, name):
            # Return 0 instead of NaN while no negatives have been seen.
            return array_ops.where(
                math_ops.greater(true_n + false_p, 0),
                math_ops.div(false_p, true_n + false_p), 0, name)

        def once_across_replicas(_, true_n, false_p):
            return compute_false_positive_rate(true_n, false_p, 'value')

        # Value tensor: reads the accumulated counts (and is added to
        # metrics_collections if given).
        fpr = _aggregate_across_replicas(
            metrics_collections, once_across_replicas, true_n, false_p)

        # Update op: running it updates both counts and yields the new rate.
        update_op = compute_false_positive_rate(true_negatives_update_op,
                                                false_positives_update_op,
                                                'update_op')
        if updates_collections:
            ops.add_to_collections(updates_collections, update_op)

        return fpr, update_op
```
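To sanity-check the graph metric on a small batch, it can help to compute the expected number by hand. The sketch below is a hypothetical pure-Python reference (`reference_false_positive_rate` is not part of TensorFlow) that assumes the same semantics as the answer's code: labels and predictions are cast to bool, each example counts with its weight, and the result is 0 when there are no negatives.

```python
from typing import Optional, Sequence


def reference_false_positive_rate(
    labels: Sequence[float],
    predictions: Sequence[float],
    weights: Optional[Sequence[float]] = None,
) -> float:
    """Hypothetical reference value for checking the TF metric by hand."""
    if weights is None:
        weights = [1.0] * len(labels)
    false_p = 0.0
    true_n = 0.0
    for label, pred, w in zip(labels, predictions, weights):
        # Mirror the bool cast: any non-zero value counts as positive.
        if not bool(label):
            if bool(pred):
                false_p += w
            else:
                true_n += w
    # Mirror the tf.where guard: 0 when there are no negatives.
    return false_p / (false_p + true_n) if false_p + true_n > 0 else 0.0


labels = [0, 0, 0, 1, 1]
predictions = [1, 0, 0, 1, 0]
print(reference_false_positive_rate(labels, predictions))  # 1 FP out of 3 negatives
print(reference_false_positive_rate(labels, predictions, [2, 1, 1, 1, 1]))  # 2 / 4
```

Because of the bool cast, any non-zero prediction counts as positive, so the metric should be fed class labels (for example `tf.argmax` over the logits, as BERT's classifier `metric_fn` does) rather than probabilities. The returned `(value, update_op)` pair can then go into an Estimator's `eval_metric_ops` dict (or a `metric_fn` result) like any built-in `tf.metrics` result.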