How to use tf.train.MonitoredTrainingSession to restore only certain variables


Question

How do you tell tf.train.MonitoredTrainingSession to restore only a subset of the variables, and run initializers on the rest?

Starting with the cifar10 tutorial .. https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_train.py

.. I created lists of the variables to restore and initialize, and specified them using a Scaffold that I pass to the MonitoredTrainingSession:

  # Saver and Scaffold here are tf.train.Saver and tf.train.Scaffold;
  # variables_initializer and constant are tf.variables_initializer and tf.constant.
  restoration_saver = Saver(var_list=restore_vars)
  restoration_scaffold = Scaffold(init_op=variables_initializer(init_vars),
                                  ready_op=constant([]),
                                  saver=restoration_saver)

But this produces the following error:

RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: None, error: Variables not initialized: conv2a/T, conv2b/T, [...]

.. where the uninitialized variables listed in the error message are the variables in my "init_vars" list.

The exception is raised by SessionManager.prepare_session(). The source code for that method seems to indicate that if the session is restored from a checkpoint, then the init_op is not run. So it looks like you can either have restored variables or initialized variables, but not both.
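
For context, a minimal sketch of how a Scaffold like the one above would typically be handed to tf.train.MonitoredTrainingSession; the checkpoint directory and train_op are hypothetical placeholders:

  import tensorflow as tf

  # restoration_scaffold is the Scaffold built above; '/tmp/cifar10_train'
  # and train_op stand in for the real checkpoint dir and training step.
  with tf.train.MonitoredTrainingSession(
      scaffold=restoration_scaffold,
      checkpoint_dir='/tmp/cifar10_train') as sess:
    while not sess.should_stop():
      sess.run(train_op)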

Answer

OK so as I suspected, I got what I wanted by implementing a new RefinementSessionManager class based on the existing tf.train.SessionManager. The two classes are almost identical, except that I modified the prepare_session method to call the init_op regardless of whether the model was loaded from a checkpoint.

This allows me to load a list of variables from the checkpoint and initialize the remaining variables in the init_op.
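
A hedged sketch of how the two variable lists might be built; filtering by a name prefix is just one possible way to decide which variables live in the checkpoint, and the "refine" scope name is hypothetical:

  import tensorflow as tf

  # Hypothetical split: variables under a new "refine" scope are assumed to be
  # absent from the checkpoint and get initialized; the rest are restored.
  all_vars = tf.global_variables()
  init_vars = [v for v in all_vars if v.op.name.startswith('refine')]
  restore_vars = [v for v in all_vars if not v.op.name.startswith('refine')]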

My prepare_session method is this:

  def prepare_session(self, master, init_op=None, saver=None,
                      checkpoint_dir=None, wait_for_checkpoint=False,
                      max_wait_secs=7200, config=None, init_feed_dict=None,
                      init_fn=None):

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    # [removed] if not is_loaded_from_checkpoint:
    # we still want to run any supplied initialization on models that
    # were loaded from checkpoint.

    if not is_loaded_from_checkpoint and init_op is None and not init_fn and self._local_init_op is None:
      raise RuntimeError("Model is not initialized and no init_op or "
                         "init_fn or local_init_op was given")
    if init_op is not None:
      sess.run(init_op, feed_dict=init_feed_dict)
    if init_fn:
      init_fn(sess)

    # [...]
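
A minimal usage sketch, assuming RefinementSessionManager keeps the tf.train.SessionManager constructor and that restore_vars / init_vars are the lists described above; the checkpoint directory is a hypothetical placeholder, and the manager is driven directly here rather than through MonitoredTrainingSession:

  import tensorflow as tf

  manager = RefinementSessionManager(
      local_init_op=tf.local_variables_initializer())
  sess = manager.prepare_session(
      master='',
      init_op=tf.variables_initializer(init_vars),  # initialize new variables
      saver=tf.train.Saver(var_list=restore_vars),  # restore the others
      checkpoint_dir='/tmp/cifar10_train')          # hypothetical path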

Hope this helps somebody else.
