In TensorFlow, is there a simple way to register a callback function for when model checkpointing happens?
Question
I was looking at the code of a TensorFlow implementation of a text encoder model called the Skip-thought Vector model: https://github.com/tensorflow/models/tree/master/skip_thoughts.
The training script contains the following code:
saver = tf.train.Saver()
tf.contrib.slim.learning.train(
    train_op=train_tensor,
    logdir=FLAGS.train_dir,
    graph=g,
    global_step=model.global_step,
    number_of_steps=training_config.number_of_steps,
    save_summaries_secs=training_config.save_summaries_secs,
    saver=saver,
    save_interval_secs=training_config.save_model_secs)
Apparently, a model checkpoint is saved every training_config.save_model_secs seconds.
I wonder whether there is a way to register some kind of callback function that gets called each time a model checkpoint has been saved. Specifically, I want to move or copy the model checkpoints to some other network location.
Recommended answer
The CheckpointSaverListener (see the code) would be one way to go, but it requires using a MonitoredTrainingSession instead of relying on the slim API, so you would need to reimplement some of the logic of the slim.train method.
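import tensorflow as tf  # TensorFlow 1.x API

# Class example adapted from the TensorFlow code linked above
class ExampleCheckpointSaverListener(tf.train.CheckpointSaverListener):
  def after_save(self, session, global_step_value):
    print('Done writing checkpoint.')
  ...

# Pseudo-code to illustrate how to use it.
# A listener is not a session hook itself; it is attached to a CheckpointSaverHook.
saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir=FLAGS.train_dir,
    save_secs=training_config.save_model_secs,
    listeners=[ExampleCheckpointSaverListener()])
your_hooks = [saver_hook]
step = 0
# save_checkpoint_secs=None turns off the default saver hook so that
# checkpoints are written only by saver_hook above.
with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_dir,
                                       save_checkpoint_secs=None,
                                       chief_only_hooks=your_hooks) as sess:
  # Your training loop
  while step < num_loop:
    # Add a feed_dict or other run arguments here if your model needs them.
    _, step = sess.run([train_tensor, model.global_step])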
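For the specific goal in the question (copying each checkpoint elsewhere), the body of after_save could delegate to a small file-copy helper. The following is only a minimal sketch under stated assumptions: the name copy_checkpoint_files and its arguments are hypothetical, the destination is assumed to be reachable as an ordinary mounted path, and a checkpoint is assumed to consist of files that share a common prefix such as model.ckpt-1000 followed by a dot and a suffix.

```python
import glob
import os
import shutil


def copy_checkpoint_files(checkpoint_prefix, dest_dir):
    """Copy every file belonging to one checkpoint into dest_dir.

    checkpoint_prefix is a path such as '/tmp/train/model.ckpt-1000'
    (hypothetical example); the files on disk are assumed to share that
    prefix and to differ only in the suffix after the following dot.
    Returns the list of destination paths that were written.
    """
    os.makedirs(dest_dir, exist_ok=True)
    copied = []
    # glob.escape guards against glob metacharacters in the prefix itself.
    pattern = glob.escape(checkpoint_prefix) + ".*"
    for src in sorted(glob.glob(pattern)):
        # shutil.copy2 returns the path of the newly created file.
        copied.append(shutil.copy2(src, dest_dir))
    return copied
```

Inside after_save, one possible call would be copy_checkpoint_files(tf.train.latest_checkpoint(FLAGS.train_dir), dest_dir), where dest_dir is whatever mounted network path you choose. A plain copy blocks the training loop while it runs, so for large checkpoints it may be worth handing the copy to a background thread.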