如何让 Keras 仅对验证数据计算某个指标? [英] How to make Keras compute a certain metric on validation data only?

查看:29
本文介绍了如何让 Keras 仅对验证数据计算某个指标?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在 TensorFlow 1.14.0 中使用 tf.keras.我已经实现了一个计算量非常大的自定义指标,如果我只是将它添加到作为 model.compile(..., metrics=[...])<提供的指标列表中,它会减慢训练过程的速度/代码>.

I'm using tf.keras with TensorFlow 1.14.0. I have implemented a custom metric that is quite computationally intensive and it slows down the training process if I simply add it to the list of metrics provided as model.compile(..., metrics=[...]).

如何让 Keras 在训练迭代期间跳过度量的计算,但在每个 epoch 结束时根据验证数据计算(并打印)?

How do I make Keras skip computation of the metric during training iterations but compute it on validation data (and print it) at the end of each epoch?

推荐答案

为此,您可以在度量计算中创建一个 tf.Variable 来确定计算是否继续进行,然后在使用回调运行测试时更新它.例如

To do this you can create a tf.Variable in the metric calculation that determines if the calculation goes ahead and then update it when a test is run using a callback. e.g.

class MyCustomMetric(tf.keras.metrics.Metrics):

    def __init__(self, **kwargs):
        # Initialise as normal and add flag variable for when to run computation
        super(MyCustomMetric, self).__init__(**kwargs)
        self.metric_variable = self.add_weight(name='metric_varaible', initializer='zeros')
        self.update_metric = tf.Variable(False)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Use conditional to determine if computation is done
        if self.update_metric:
            # run computation
            self.metric_variable.assign_add(computation_result)

    def result(self):
        return self.metric_variable

    def reset_states(self):
        self.metric_variable.assign(0.)

class ToggleMetrics(tf.keras.callbacks.Callback):
    '''On test begin (i.e. when evaluate() is called or 
     validation data is run during fit()) toggle metric flag '''
    def on_test_begin(self, logs):
        for metric in self.model.metrics:
            if 'MyCustomMetric' in metric.name:
                metric.on.assign(True)
    def on_test_end(self,  logs):
        for metric in self.model.metrics:
            if 'MyCustomMetric' in metric.name:
                metric.on.assign(False)

这篇关于如何让 Keras 仅对验证数据计算某个指标?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆