SessionRunHook 的成员函数被调用的顺序是什么? [英] What is the sequence of SessionRunHook's member function to be called?
问题描述
阅读API DOC后,我也无法理解 SessionRunHook 的用法.比如SessionRunHook的成员顺序是什么要调用的函数?是 after_create_session ->before_run ->开始 ->after_run ->结束
?而且我也找不到详细例子的教程,有更详细的说明吗?
After read the API DOC, I also can't understand the usage of SessionRunHook. For example, what is the sequence of SessionRunHook's member
function to be called? Is it after_create_session -> before_run -> begin -> after_run -> end
?
And I can't find the tutorial with detailed examples, is there more detailed explanation?
推荐答案
您可以在这里找到教程,有点长但你可以跳过构建网络的部分.或者,您可以根据我的经验阅读下面的小摘要.
You can find a tutorial here, a little long but you can jump the part of building the network. Or you can read my small summary below, based on my experiance.
首先,应该使用 MonitoredSession
而不是普通的 Session
.
First, MonitoredSession
should be used instead of normal Session
.
SessionRunHook 扩展了 session.run()
对 MonitoredSession
的调用.
A SessionRunHook extends
session.run()
calls for theMonitoredSession
.
然后可以找到一些常见的SessionRunHook
类此处.一个简单的方法是 LoggingTensorHook
,但您可能希望在导入后添加以下行,以便在运行时查看日志:
Then some common SessionRunHook
classes can be found here. A simple one is LoggingTensorHook
but you might want to add the following line after your imports for seeing the logs when running:
tf.logging.set_verbosity(tf.logging.INFO)
或者您可以选择实现自己的 SessionRunHook
类.一个简单的来自 cifar10 教程>
Or you have option to implement your own SessionRunHook
class. A simple one is from cifar10 tutorial
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
其中 loss
在类之外定义.这个 _LoggerHook
使用 print
打印信息,而 LoggingTensorHook
使用 tf.logging.INFO
.
where loss
is defined outside the class. This _LoggerHook
uses print
to print the information while LoggingTensorHook
uses tf.logging.INFO
.
最后,为了更好地理解它是如何工作的,执行顺序由带有 MonitoredSession
这里:
At last, for better understanding how it works, the execution order is presented by pseudocode with MonitoredSession
here:
call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested: # py code: while not mon_sess.should_stop():
call hooks.before_run()
try:
results = sess.run(merged_fetches, feed_dict=merged_feeds)
except (errors.OutOfRangeError, StopIteration):
break
call hooks.after_run()
call hooks.end()
sess.close()
希望这会有所帮助.
这篇关于SessionRunHook 的成员函数被调用的顺序是什么?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!