How to use tf.train.Saver in SessionRunHook?

Problem description

I have trained many sub-models, each of which is a part of the final model. I want to use those pretrained sub-models to initialize the final model's parameters, so I tried to use a SessionRunHook to load the parameters from the sub-models' ckpt files into the final model. The following code fails. Any advice would be appreciated, thanks! The error info is:

Traceback (most recent call last):
  File "train_high_api_local.py", line 282, in <module>
    tf.app.run()
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 124, in run
    _sys.exit(main(argv))
  File "train_high_api_local.py", line 266, in main
    clf_.train(input_fn=lambda: read_file([tables[0]], epochs_per_eval), steps=None, hooks=[hook_test])     # input yield: x, y
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 314, in train
  .......
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 674, in create_session
    hook.after_create_session(self.tf_sess, self.coord)
  File "train_high_api_local.py", line 102, in after_create_session
    saver = tf.train.Saver([ti])    # TODO: ERROR INFO:  Graph is finalized and cannot be modified.
  .......
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3135, in create_op
    self._check_not_finalized()
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2788, in _check_not_finalized
    raise RuntimeError("Graph is finalized and cannot be modified.")
RuntimeError: Graph is finalized and cannot be modified.

The code is:

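import tensorflow as tf
from tensorflow.python.training import session_run_hook

# COLUMNS, FEA_DIM, FLAGS, weight_norm2, model_fn_wide2deep, ckpt_dir,
# read_file and tables come from parts of train_high_api_local.py not shown here.

class SetTensor(session_run_hook.SessionRunHook):
    """ like tf.train.LoggingTensorHook """
    def __init__(self, tensor_names):
        # Names of the variables to load from the sub-model checkpoint.
        self._tensor_names = tensor_names

    def after_create_session(self, session, coord):
        """ Called when new TensorFlow session is created: graph is finalized and ops can no longer be added. """
        graph = tf.get_default_graph()
        ti = graph.get_tensor_by_name("h_1_15/bias:0")
        with session.as_default():
            with tf.name_scope("rewrite"):
                # Fails: building a Saver adds save/restore ops, but the graph is already finalized here.
                saver = tf.train.Saver([ti])    # TODO: ERROR INFO:  Graph is finalized and cannot be modified.
                saver.restore(session, "/Users/zhouliaoming/data/credit_dnn/model_retrain/rm_gene_v2_sall/model.ckpt-2102")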

def main(unused_argv):
    """ train """
    norm_all_func = lambda x:  tf.cond(x>1, lambda: tf.log(x), lambda: tf.identity(x))
    feature_columns=[[tf.feature_column.numeric_column(COLUMNS[i], shape=fi, normalizer_fn=lambda x: tf.py_func(weight_norm2, [x], tf.float32) )] for i, fi in enumerate(FEA_DIM)]  # normalized: running OK!
    ## use self-defined model
    param = {"learning_rate": 0.0001, "feature_columns": feature_columns, "isanalysis": FLAGS.isanalysis, "isall": False}
    clf_ = tf.estimator.Estimator(model_fn=model_fn_wide2deep, params=param, model_dir=ckpt_dir)
    hook_test = SetTensor(["h_1_15/bias", "h_1_15/kernel"])
    epochs_per_eval = 1
    for n in range(int(FLAGS.num_epochs/epochs_per_eval)):
        # train num_epochs
        clf_.train(input_fn=lambda: read_file([tables[0]], epochs_per_eval), steps=None, hooks=[hook_test])     # input yield: x, y

Recommended answer

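SessionRunHook is not meant for this use case. As the error says, you cannot change the graph once it has been finalized, and by the time after_create_session() is called the graph is already finalized, even before the first sess.run().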
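
The failure comes specifically from building the Saver inside after_create_session(). According to the SessionRunHook documentation, begin() is called while the default graph can still be modified, and the graph is finalized only after that call. So if you still want to keep the restore inside a hook, one possible workaround is to build the Saver in begin() and only call restore() in after_create_session(). Below is a minimal sketch, assuming TF 1.x graph mode (written against tf.compat.v1 so it also runs on TF 2.x); RestoreSubsetHook, run_demo and the checkpoint values are hypothetical, and only the variable name h_1_15/bias is borrowed from the question.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class RestoreSubsetHook(tf.train.SessionRunHook):
    """Hypothetical hook: loads selected variables from another checkpoint."""

    def __init__(self, var_names, ckpt_path):
        self._var_names = var_names
        self._ckpt_path = ckpt_path
        self._saver = None

    def begin(self):
        # begin() runs before the graph is finalized, so the Saver's
        # save/restore ops can still be added to the graph here.
        by_name = {v.op.name: v for v in tf.global_variables()}
        var_list = {name: by_name[name] for name in self._var_names}
        self._saver = tf.train.Saver(var_list=var_list)

    def after_create_session(self, session, coord):
        # The graph is finalized now; running the existing restore ops is fine.
        self._saver.restore(session, self._ckpt_path)


def run_demo():
    ckpt_path = os.path.join(tempfile.mkdtemp(), "sub_model.ckpt")

    # Stand-in for a pretrained sub-model checkpoint (made-up values).
    with tf.Graph().as_default():
        with tf.variable_scope("h_1_15"):
            tf.get_variable("bias", initializer=tf.constant([1.0, 2.0, 3.0]))
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            tf.train.Saver().save(sess, ckpt_path)

    # The final model: same variable name with a different initial value,
    # plus a new variable that is not in the checkpoint.
    with tf.Graph().as_default():
        with tf.variable_scope("h_1_15"):
            bias = tf.get_variable("bias", initializer=tf.zeros([3]))
        head = tf.get_variable("head", initializer=tf.ones([2]))
        hook = RestoreSubsetHook(["h_1_15/bias"], ckpt_path)
        with tf.train.MonitoredSession(hooks=[hook]) as sess:
            return sess.run([bias, head])


restored_bias, head_value = run_demo()
print(restored_bias, head_value)
```

One caveat: after_create_session() runs for every new session, and also when a session is recovered. In the question's loop, each clf_.train() call creates a new session after the Estimator has restored the latest checkpoint from model_dir, so a hook like this would overwrite the fine-tuned values with the pretrained ones on every epoch unless it is only passed to the first train() call.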

You can assign variables using saver.restore() in your "normal code". You don't have to be inside any hooks.
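
A minimal sketch of that "normal code" route, assuming plain TF 1.x graph mode with a session you control (written with tf.compat.v1; the checkpoint and its values are made up for illustration): build a Saver that covers only the sub-model's variables while the graph is still being constructed, initialize everything, then restore the subset on top.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), "sub_model.ckpt")

# Stand-in for one pretrained sub-model (made-up values).
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        tf.get_variable("bias", initializer=tf.constant([1.0, 2.0, 3.0]))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

# The final model: the sub-model's layer plus a new layer.
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        bias = tf.get_variable("bias", initializer=tf.zeros([3]))
    with tf.variable_scope("head"):
        head_w = tf.get_variable("w", initializer=tf.ones([3, 1]))

    # Build the Saver before any session exists; it only covers the
    # variables whose names start with "h_1_15".
    sub_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="h_1_15")
    sub_saver = tf.train.Saver(var_list=sub_vars)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # initialize every variable
        sub_saver.restore(sess, ckpt_path)            # then overwrite the sub-model's
        restored_bias, head_value = sess.run([bias, head_w])

print(restored_bias, head_value)
```

With tf.estimator.Estimator, as in the question, user code never holds the session. The usual equivalents there are calling tf.train.init_from_checkpoint() inside model_fn after the layers are built, or passing tf.estimator.WarmStartSettings through the Estimator's warm_start_from argument. Both work by changing the variables' initializers, so they should only apply when training starts without a checkpoint in model_dir and will not clobber progress on later train() calls.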

Also, if you want to restore many variables and can match them to their names and shapes in a checkpoint, you might want to take a look at https://gist.github.com/iganichev/d2d8a0b1abc6b15d4a07de83171163d4. It shows some example code to restore a subset of variables.
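
The gist itself is not reproduced here. As a separate, hypothetical sketch of the same idea (matching_variables is an illustrative helper, not a TensorFlow API; the checkpoint values are made up), tf.train.list_variables() returns the (name, shape) pairs stored in a checkpoint, which can be used to keep only the graph variables whose name and shape both match before building the Saver. Variables that changed shape or are new in the final model are simply left at their initial values.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def matching_variables(ckpt_path):
    """Hypothetical helper: map checkpoint names to global variables whose
    name and shape both match an entry stored in the checkpoint."""
    ckpt_shapes = dict(tf.train.list_variables(ckpt_path))
    matched = {}
    for var in tf.global_variables():
        name = var.op.name
        if name in ckpt_shapes and var.shape.as_list() == list(ckpt_shapes[name]):
            matched[name] = var
    return matched


ckpt_path = os.path.join(tempfile.mkdtemp(), "sub_model.ckpt")

# Stand-in sub-model checkpoint (made-up values).
with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        tf.get_variable("kernel", initializer=tf.fill([4, 3], 0.5))
        tf.get_variable("bias", initializer=tf.constant([1.0, 2.0, 3.0]))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

with tf.Graph().as_default():
    with tf.variable_scope("h_1_15"):
        # kernel has a different shape in the final model, so it is skipped.
        tf.get_variable("kernel", initializer=tf.zeros([5, 3]))
        bias = tf.get_variable("bias", initializer=tf.zeros([3]))
    tf.get_variable("head", initializer=tf.ones([2]))  # not in the checkpoint

    to_restore = matching_variables(ckpt_path)
    saver = tf.train.Saver(var_list=to_restore)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt_path)
        restored_bias = sess.run(bias)

restored_names = sorted(to_restore)
print(restored_names, restored_bias)
```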
