SessionRunHook 的成员函数被调用的顺序是什么? [英] What is the sequence of SessionRunHook's member function to be called?

查看:32
本文介绍了SessionRunHook 的成员函数被调用的顺序是什么?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

阅读API DOC后,我也无法理解 SessionRunHook 的用法.比如SessionRunHook的成员顺序是什么要调用的函数?是 after_create_session ->before_run ->开始 ->after_run ->结束 ?而且我也找不到详细例子的教程,有更详细的说明吗?

After read the API DOC, I also can't understand the usage of SessionRunHook. For example, what is the sequence of SessionRunHook's member function to be called? Is it after_create_session -> before_run -> begin -> after_run -> end ? And I can't find the tutorial with detailed examples, is there more detailed explanation?

推荐答案

您可以在这里找到教程,有点长但你可以跳过构建网络的部分.或者,您可以根据我的经验阅读下面的小摘要.

You can find a tutorial here, a little long but you can jump the part of building the network. Or you can read my small summary below, based on my experiance.

首先,应该使用 MonitoredSession 而不是普通的 Session.

First, MonitoredSession should be used instead of normal Session.

SessionRunHook 扩展了 session.run()MonitoredSession 的调用.

A SessionRunHook extends session.run() calls for the MonitoredSession.

然后可以找到一些常见的SessionRunHook此处.一个简单的方法是 LoggingTensorHook,但您可能希望在导入后添加以下行,以便在运行时查看日志:

Then some common SessionRunHook classes can be found here. A simple one is LoggingTensorHook but you might want to add the following line after your imports for seeing the logs when running:

tf.logging.set_verbosity(tf.logging.INFO)

或者您可以选择实现自己的 SessionRunHook 类.一个简单的来自 cifar10 教程

Or you have option to implement your own SessionRunHook class. A simple one is from cifar10 tutorial

class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()

  def before_run(self, run_context):
    self._step += 1
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.

  def after_run(self, run_context, run_values):
    if self._step % FLAGS.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time

      loss_value = run_values.results
      examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
      sec_per_batch = float(duration / FLAGS.log_frequency)

      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print (format_str % (datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

其中 loss 在类之外定义.这个 _LoggerHook 使用 print 打印信息,而 LoggingTensorHook 使用 tf.logging.INFO.

where loss is defined outside the class. This _LoggerHook uses print to print the information while LoggingTensorHook uses tf.logging.INFO.

最后,为了更好地理解它是如何工作的,执行顺序由带有 MonitoredSession 这里:

At last, for better understanding how it works, the execution order is presented by pseudocode with MonitoredSession here:

  call hooks.begin()
  sess = tf.Session()
  call hooks.after_create_session()
  while not stop is requested:  # py code: while not mon_sess.should_stop():
    call hooks.before_run()
    try:
      results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
      break
    call hooks.after_run()
  call hooks.end()
  sess.close()

希望这会有所帮助.

这篇关于SessionRunHook 的成员函数被调用的顺序是什么?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆