In Tensorflow, is there a simple way to register a callback function for when model checkpointing happens?


Question

I was looking at the code from a Tensorflow's implementation of a text encoder model called Skip-thought Vector models: https://github.com/tensorflow/models/tree/master/skip_thoughts.

The code in the training script contains the following:

saver = tf.train.Saver()

tf.contrib.slim.learning.train(
  train_op=train_tensor,
  logdir=FLAGS.train_dir,
  graph=g,
  global_step=model.global_step,
  number_of_steps=training_config.number_of_steps,
  save_summaries_secs=training_config.save_summaries_secs,
  saver=saver,
  save_interval_secs=training_config.save_model_secs)

Apparently, a model checkpoint is saved every training_config.save_model_secs seconds.

I wonder if there's a way to register some kind of callback function that gets called after model checkpointing happens every time. Specifically, I want to move/copy the model checkpoints to some other network locations.

Answer

The CheckpointSaverListener (see the code) would be one way to go, but it requires using a MonitoredTrainingSession instead of relying on the slim API, so you would need to reimplement some of the logic of the slim.train method.

import tensorflow as tf

# Class example from the TensorFlow link above
class ExampleCheckpointSaverListener(tf.train.CheckpointSaverListener):
    def after_save(self, session, global_step_value):
        print('Done writing checkpoint.')

# A listener is attached to a CheckpointSaverHook, which in turn
# is passed to the MonitoredTrainingSession as a chief-only hook
listener = ExampleCheckpointSaverListener()
saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir=FLAGS.train_dir,
    save_secs=training_config.save_model_secs,
    listeners=[listener])

step = 0
with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        chief_only_hooks=[saver_hook],
        # disable the default saver hook so only ours runs
        save_checkpoint_secs=None) as sess:
    # Your training loop
    while step < training_config.number_of_steps:
        _, step = sess.run([train_tensor, model.global_step])
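Since the goal here is to move/copy the checkpoints to another location, the copy step itself can be sketched as a plain helper that `after_save` would call. This is only a sketch: it assumes the TF 1.x default `model.ckpt-<step>` file-name prefix, and `copy_checkpoint` and its arguments are illustrative names, not part of the TensorFlow API.

```python
import glob
import os
import shutil

def copy_checkpoint(checkpoint_dir, global_step, dest_dir):
    """Copy every file belonging to checkpoint model.ckpt-<global_step>
    (e.g. the .index, .meta and .data-* shards) into dest_dir."""
    pattern = os.path.join(checkpoint_dir, 'model.ckpt-%d.*' % global_step)
    copied = []
    for path in sorted(glob.glob(pattern)):
        shutil.copy(path, dest_dir)
        copied.append(os.path.basename(path))
    return copied
```

Inside the listener, `after_save(self, session, global_step_value)` would then call `copy_checkpoint(FLAGS.train_dir, global_step_value, your_network_dir)`; for a remote destination that is not mounted as a filesystem, you would swap `shutil.copy` for whatever transfer mechanism you use.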
