在Keras中使用tf.metrics? [英] Use tf.metrics in Keras?

查看:318
本文介绍了在Keras中使用tf.metrics?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我对 specificity_at_sensitivity 特别感兴趣.浏览 Keras文档:

from keras import metrics

model.compile(loss='mean_squared_error',
              optimizer='sgd',
              metrics=[metrics.mae, metrics.categorical_accuracy])

但是看起来metrics列表必须具有arity 2函数,接受(y_true, y_pred)并返回单个张量值.

But it looks like the metrics list must have functions of arity 2, accepting (y_true, y_pred) and returning a single tensor value.

目前这是我的工作方式:

Currently here is how I do things:

from sklearn.metrics import confusion_matrix

predictions = model.predict(x_test)
y_test = np.argmax(y_test, axis=-1)
predictions = np.argmax(predictions, axis=-1)
c = confusion_matrix(y_test, predictions)
print('Confusion matrix:\n', c)
print('sensitivity', c[0, 0] / (c[0, 1] + c[0, 0]))
print('specificity', c[1, 1] / (c[1, 1] + c[1, 0]))

这种方法的缺点是,只有在培训结束后我才能得到我关心的输出.宁愿每10个周期左右获取一次指标.

The disadvantage of this approach, is I only get the output I care about when training has finished. Would prefer to get metrics every 10 epochs or so.

推荐答案

我在 tf.metrics.specificity_at_sensitivity 非常感兴趣>,我建议采用以下解决方法(灵感来自 BogdanRuzh's 解决方案):

I've found a related issue on github, and it seems that tf.metrics are still not supported by Keras models. However, in case you are very interested in using tf.metrics.specificity_at_sensitivity, I would suggest the following workaround (inspired by BogdanRuzh's solution):

def specificity_at_sensitivity(sensitivity, **kwargs):
    def metric(labels, predictions):
        # any tensorflow metric
        value, update_op = tf.metrics.specificity_at_sensitivity(labels, predictions, sensitivity, **kwargs)

        # find all variables created for this metric
        metric_vars = [i for i in tf.local_variables() if 'specificity_at_sensitivity' in i.name.split('/')[2]]

        # Add metric variables to GLOBAL_VARIABLES collection.
        # They will be initialized for new session.
        for v in metric_vars:
            tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, v)

        # force to update metric values
        with tf.control_dependencies([update_op]):
            value = tf.identity(value)
            return value
    return metric


model.compile(loss='mean_squared_error',
              optimizer='sgd',
              metrics=[metrics.mae,
                       metrics.categorical_accuracy,
                       specificity_at_sensitivity(0.5)])


更新:

您可以使用 model.evaluate 来获取训练后的指标.

You can use model.evaluate to retrieve the metrics after training.

这篇关于在Keras中使用tf.metrics?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆