Implement early stopping in tf.estimator.DNNRegressor using the available training hooks


Question


I am new to tensorflow and want to implement early stopping in tf.estimator.DNNRegressor with the available Training Hooks for the MNIST dataset. The early stopping hook should stop training if the loss does not improve for some specified number of steps. The Tensorflow documentation only provides an example for logging hooks. Can someone write a code snippet for implementing it?

Answer

Here is an example EarlyStoppingHook implementation:

import numpy as np
import tensorflow as tf
import logging
from tensorflow.python.training import session_run_hook


class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests stop at a specified step."""

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        """
        """
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        if mode not in ['auto', 'min', 'max']:
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def begin(self):
        # Convert names to tensors if given
        graph = tf.get_default_graph()
        self.monitor = graph.as_graph_element(self.monitor)
        if isinstance(self.monitor, tf.Operation):
            self.monitor = self.monitor.outputs[0]

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        current = run_values.results

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()

This implementation is based on the Keras implementation.
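The stopping logic itself does not depend on TensorFlow, so it can be sanity-checked in plain Python/NumPy. The sketch below (helper name and loss sequence are made up for illustration) replays the hook's `after_run` bookkeeping in 'min' mode:

```python
import numpy as np

def steps_until_stop(losses, patience=0, min_delta=0):
    """Replay the hook's after_run bookkeeping over a list of loss values
    and return the step index at which training would be stopped
    (or None if it never stops)."""
    monitor_op = np.less          # 'min' mode: a smaller loss is better
    min_delta = -abs(min_delta)   # mirrors `self.min_delta *= -1`
    best = np.inf
    wait = 0
    for step, current in enumerate(losses):
        if monitor_op(current - min_delta, best):
            best = current        # improvement: remember it, reset wait
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return step       # request_stop() would fire here
    return None

# Loss improves for three steps, then plateaus; with patience=2 the
# hook requests a stop on the second non-improving step (index 4).
print(steps_until_stop([1.0, 0.9, 0.8, 0.8, 0.8, 0.8], patience=2))
```

Note that `min_delta` is folded into the comparison so that a tiny improvement smaller than `min_delta` still counts as "no improvement", exactly as in the hook.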


To use it with the CNN MNIST example, create the hook and pass it to train:

early_stopping_hook = EarlyStoppingHook(monitor='sparse_softmax_cross_entropy_loss/value', patience=10)

mnist_classifier.train(
  input_fn=train_input_fn,
  steps=20000,
  hooks=[logging_hook, early_stopping_hook])


Here sparse_softmax_cross_entropy_loss/value is the name of the loss op in that example.


It looks like there is no "official" way of finding the loss node when using estimators (or at least I could not find one).


For DNNRegressor this node is named dnn/head/weighted_loss/Sum.


Here is how to find it in the graph:


  1. Start tensorboard in the model directory. In my case I didn't set any directory, so the estimator used a temporary directory and printed this line:
    WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpInj8SC
    Then start tensorboard:

tensorboard --logdir /tmp/tmpInj8SC


  2. Open it in a browser and navigate to the GRAPHS tab.


    Find the loss in the graph. Expand the blocks in this sequence: dnn → head → weighted_loss, then click the Sum node (note that there is a summary node named loss connected to it).


    The name shown in the info window on the right is the name of the selected node; this is what needs to be passed to the monitor argument of EarlyStoppingHook.


    The loss node of DNNClassifier has the same name by default. Both DNNClassifier and DNNRegressor have an optional loss_reduction argument that influences the loss node's name and behavior (it defaults to losses.Reduction.SUM).


    There is a way of finding the loss without looking at the graph: you can use the GraphKeys.LOSSES collection to get it. This only works after training has started, so you can use it only from within a hook.


    For example, you can remove the monitor argument from the EarlyStoppingHook class and change its begin function to always use the first loss in the collection:

    self.monitor = tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES)[0]
    


    You also probably need to check that there is a loss in the collection.
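That check can be kept framework-agnostic; a minimal sketch (the helper name here is made up, and in the real begin() you would pass in the result of tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES)):

```python
def first_loss_or_fail(losses):
    """Return the first entry of a LOSSES collection list, failing
    loudly when the collection is empty (e.g. begin() ran before any
    loss was added to the graph)."""
    if not losses:
        raise RuntimeError(
            'GraphKeys.LOSSES is empty; pass an explicit tensor name '
            'to EarlyStoppingHook instead.')
    return losses[0]

# Inside begin() this would be used as (hypothetical wiring):
#   self.monitor = first_loss_or_fail(
#       tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES))
```

Failing fast here is preferable to letting before_run crash later with a less obvious error about a missing fetch.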

