tensorflow: saving and restoring session

Problem description

I am trying to implement a suggestion from answers: Tensorflow: how to save/restore a model?

I have an object which wraps a tensorflow model in a sklearn style.

import numpy as np
import tensorflow as tf

class tflasso():
    saver = tf.train.Saver()
    def __init__(self,
                 learning_rate = 2e-2,
                 training_epochs = 5000,
                 display_step = 50,
                 BATCH_SIZE = 100,
                 ALPHA = 1e-5,
                 checkpoint_dir = "./",
                 ):
        ...

    def _create_network(self):
       ...


    def _load_(self, sess, checkpoint_dir = None):
        if checkpoint_dir:
            self.checkpoint_dir = checkpoint_dir

        print("loading a session")
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise Exception("no checkpoint found")
        return

    def fit(self, train_X, train_Y , load = True):
        self.X = train_X
        self.xlen = train_X.shape[1]
        # n_samples = y.shape[0]

        self._create_network()
        tot_loss = self._create_loss()
        optimizer = tf.train.AdagradOptimizer( self.learning_rate).minimize(tot_loss)

        # Initializing the variables
        init = tf.initialize_all_variables()
        # training per se
        getb = batchgen( self.BATCH_SIZE)

        yvar = train_Y.var()
        print(yvar)
        # Launch the graph
        NUM_CORES = 3  # Choose how many cores to use.
        sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES,
                                                           intra_op_parallelism_threads=NUM_CORES)
        with tf.Session(config= sess_config) as sess:
            sess.run(init)
            if load:
                self._load_(sess)
            # Fit all training data
            for epoch in range( self.training_epochs):
                for (_x_, _y_) in getb(train_X, train_Y):
                    _y_ = np.reshape(_y_, [-1, 1])
                    sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_})
                # Display logs per epoch step
                if (1+epoch) % self.display_step == 0:
                    cost = sess.run(tot_loss,
                            feed_dict={ self.vars.xx: train_X,
                                    self.vars.yy: np.reshape(train_Y, [-1, 1])})
                    rsq =  1 - cost / yvar
                    logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq)
                    print(logstr )
                    self.saver.save(sess, self.checkpoint_dir + 'model.ckpt',
                       global_step= 1+ epoch)

            print("Optimization Finished!")
        return self

When I run:

tfl = tflasso()
tfl.fit( train_X, train_Y , load = False)

I get the output:

Epoch:   50 cost = 38.4705  R^2 = -1.2036
    b1: 0.118122
Epoch:  100 cost = 26.4506  R^2 = -0.5151
    b1: 0.133597
Epoch:  150 cost = 22.4330  R^2 = -0.2850
    b1: 0.142261
Epoch:  200 cost = 20.0361  R^2 = -0.1477
    b1: 0.147998

However, when I try to recover the parameters (even without killing the object): tfl.fit( train_X, train_Y , load = True)

I get strange results. First of all, the loaded value does not correspond to the saved one.

loading a session
loaded b1: 0.1          <------- Loaded another value than saved
Epoch:   50 cost = 30.8483  R^2 = -0.7670
    b1: 0.137484  

What is the right way to load, and probably first inspect the saved variables?

Recommended answer

TL;DR: You should try to rework this class so that self._create_network() is called (i) only once, and (ii) before the tf.train.Saver() is constructed.

There are two subtle issues here, which are due to the code structure and the default behavior of the tf.train.Saver constructor. When you construct a saver with no arguments (as in your code), it collects the current set of variables in your program and adds ops to the graph for saving and restoring them. In your code, the saver is a class attribute, so it is constructed when the class body runs, and at that point there are no variables (because _create_network() has not yet been called). As a result, the checkpoint should be empty.
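To check what a checkpoint really contains (the question also asks how to inspect saved values), something like the following may help. It is only a sketch: it assumes a TensorFlow build that exposes the TF1-style API as tensorflow.compat.v1, and read_checkpoint is a hypothetical helper name, not part of TensorFlow or of the original code.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graphs and sessions


def read_checkpoint(checkpoint_dir):
    """Return {variable name: saved value} for the newest checkpoint.

    Hypothetical helper: it only reads the checkpoint files and does not
    need the model graph to be rebuilt.
    """
    path = tf.train.latest_checkpoint(checkpoint_dir)
    if path is None:
        raise IOError("no checkpoint found in " + checkpoint_dir)
    # list_variables yields (name, shape) pairs; load_variable reads one value.
    return {name: tf.train.load_variable(path, name)
            for name, _shape in tf.train.list_variables(path)}
```

Calling read_checkpoint("./") after training and printing the resulting dictionary should show whether b1 and the other variables were written at all, and under which names.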

The second issue is that, by default, a saved checkpoint is a map from the name property of a variable to its current value. If you create two variables with the same name, TensorFlow automatically "uniquifies" them:

v = tf.Variable(..., name="weights")
assert v.op.name == "weights"
w = tf.Variable(..., name="weights")
assert w.op.name == "weights_1"  # The "_1" is added by TensorFlow.
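The renaming, and its effect on checkpoint keys, can be observed directly. The sketch below assumes the TF1-style API is available as tensorflow.compat.v1; the function name and the temporary directory are illustrative only.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graphs and sessions


def checkpoint_keys_for_duplicate_names():
    """Create two variables that request the same name and save them."""
    with tf.Graph().as_default():
        v = tf.Variable(0.1, name="weights")
        w = tf.Variable(0.2, name="weights")  # same requested name as v
        op_names = (v.op.name, w.op.name)
        saver = tf.train.Saver()  # built after both variables exist
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")
            path = saver.save(sess, prefix)
    # The keys stored in the checkpoint follow the uniquified names.
    keys = sorted(name for name, _shape in tf.train.list_variables(path))
    return op_names, keys
```

Both returned collections should contain "weights" and "weights_1", which is why rebuilding the network a second time in the same graph produces variables that no longer match the saved keys.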

The consequence is that, when self._create_network() runs again in the second call to tfl.fit(), the variables will all have names that differ from those stored in the checkpoint (or that would have been stored, had the saver been constructed after the network). You can avoid this behavior by passing a name-to-Variable dictionary to the saver constructor, but this is usually quite awkward.
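For completeness, here is roughly what the name-to-Variable dictionary looks like in practice. This is a sketch under assumptions: it uses tensorflow.compat.v1, a made-up variable b1 with made-up values, and a hypothetical function name; it simulates the renaming by creating a second variable called b1 in the restoring graph.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graphs and sessions


def restore_under_checkpoint_name():
    """Save 'b1', then restore it into a variable that was renamed 'b1_1'."""
    prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

    # First graph: save a variable whose checkpoint key is "b1".
    with tf.Graph().as_default():
        tf.Variable(0.25, name="b1")
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, prefix)

    # Second graph: the name "b1" is already taken, so the new variable
    # is uniquified to "b1_1" and would not match the checkpoint by default.
    with tf.Graph().as_default():
        tf.Variable(0.0, name="b1")
        renamed = tf.Variable(0.1, name="b1")
        # Explicit mapping: checkpoint key -> variable to restore into.
        saver = tf.train.Saver({"b1": renamed})
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, prefix)
            return renamed.op.name, float(sess.run(renamed))
```

The mapping has to be maintained by hand for every variable, which is the awkwardness mentioned above.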

There are two main workarounds:

  1. In each call to tflasso.fit(), create the whole model afresh, by defining a new tf.Graph, then in that graph building the network and creating a tf.train.Saver.

  2. (RECOMMENDED) Create the network, then the tf.train.Saver, in the tflasso constructor, and reuse this graph on each call to tflasso.fit(). Note that you might need to do some more work to reorganize things (in particular, I'm not sure what you do with self.X and self.xlen), but it should be possible to achieve this with placeholders and feeding.
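A minimal sketch of the recommended structure might look like this. It is not the original tflasso: TinyLinearModel, its plain squared-error loss, the n_features argument, and the b1_at_start / b1_at_end attributes are all assumptions made for illustration, and the code relies on tensorflow.compat.v1 being available.

```python
import os

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graphs and sessions


class TinyLinearModel(object):
    """Hypothetical stand-in for tflasso: graph and saver are built once."""

    def __init__(self, n_features, checkpoint_dir, learning_rate=0.1):
        self.checkpoint_dir = checkpoint_dir
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Placeholders replace the data-dependent attributes.
            self.xx = tf.placeholder(tf.float32, [None, n_features])
            self.yy = tf.placeholder(tf.float32, [None, 1])
            self.w = tf.Variable(tf.zeros([n_features, 1]), name="w")
            self.b1 = tf.Variable(0.1, name="b1")
            pred = tf.matmul(self.xx, self.w) + self.b1
            self.loss = tf.reduce_mean(tf.square(pred - self.yy))
            optimizer = tf.train.AdagradOptimizer(learning_rate)
            self.train_op = optimizer.minimize(self.loss)
            self.init = tf.global_variables_initializer()
            # Built last, so it sees every variable, optimizer slots included.
            self.saver = tf.train.Saver()

    def fit(self, train_X, train_Y, epochs=10, load=True):
        feed = {self.xx: train_X, self.yy: np.reshape(train_Y, [-1, 1])}
        with tf.Session(graph=self.graph) as sess:
            sess.run(self.init)
            if load:
                path = tf.train.latest_checkpoint(self.checkpoint_dir)
                if path is None:
                    raise IOError("no checkpoint found")
                self.saver.restore(sess, path)
            self.b1_at_start = float(sess.run(self.b1))
            for _ in range(epochs):
                sess.run(self.train_op, feed_dict=feed)
            self.b1_at_end = float(sess.run(self.b1))
            self.saver.save(sess, os.path.join(self.checkpoint_dir, "model.ckpt"))
        return self
```

Because the same graph and saver are reused, a second fit(..., load=True) should start from exactly the b1 value that the previous call saved, rather than from the initial 0.1.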
