TensorFlow checkpoint save and read

Problem description

I have a TensorFlow-based neural network and a set of variables.

The training function looks like this:

def train(load, step):
    """
    Defining the neural network is skipped here
    """

    train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)
    # Saver
    saver = tf.train.Saver()

    if not load:
        # Initializing variables
        sess.run(tf.initialize_all_variables())
    else:
        saver.restore(sess, 'Variables/map.ckpt')
        print 'Model Restored!'

    # Perform stochastic gradient descent
    # (x, y_, train and label come from the skipped network definition;
    # note that `train` here is the training data, not this function)
    for i in xrange(step):
        train_step.run(feed_dict = {x: train, y_: label})

    # Save model
    save_path = saver.save(sess, 'Variables/map.ckpt')
    print 'Model saved in file: ', save_path
    print 'Training Done!'

I was calling the training function like this:

# First train
train(False, 1)
# Following train
for i in xrange(10):
    train(True, 10)

I did this kind of training because I needed to feed different datasets to my model. However, if I call the train function in this way, TensorFlow will generate an error message indicating that it cannot read the saved model from the file.

After some experiments I found that this happened because checkpoint saving was slow: before the file was fully written to disk, the next train call would start reading it, thus generating the error.

I tried using the time.sleep() function to add a delay between calls, but it didn't work.

Does anyone know how to work around this kind of write/read error? Thank you very much!

Recommended answer

There is a subtle issue in your code: each time you call the train() function, more nodes are added to the same TensorFlow graph, for all the model variables and the rest of the neural network. This means that each time you construct a tf.train.Saver(), it includes all of the variables from the previous calls to train(). Each time you recreate the model, the variables are created with an extra _N suffix to give them a unique name:

  1. Saver constructed with variables var_a, var_b.
  2. Saver constructed with variables var_a, var_b, var_a_1, var_b_1.
  3. Saver constructed with variables var_a, var_b, var_a_1, var_b_1, var_a_2, var_b_2.
  4. etc.
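
For example, here is a minimal sketch of that uniquification behaviour (using the same old-style TensorFlow API as the question; 'var_a' is just an illustrative name):

import tensorflow as tf

with tf.Graph().as_default():
    # First creation: the variable op gets the requested name.
    var_a = tf.Variable(0.0, name='var_a')
    print var_a.name    # -> 'var_a:0'

    # Second creation with the same name in the same graph:
    # TensorFlow appends an _N suffix to keep op names unique.
    var_a_1 = tf.Variable(0.0, name='var_a')
    print var_a_1.name  # -> 'var_a_1:0'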

The default behavior for a tf.train.Saver is to associate each variable with the name of the corresponding op. This means that var_a_1 won't be initialized from var_a, because they end up with different names.
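
If you ever do need to load checkpointed values into variables whose op names differ from the names in the checkpoint file, tf.train.Saver also accepts a dictionary mapping checkpoint names to variable objects. A minimal sketch, assuming the checkpoint at 'Variables/map.ckpt' stores the variable under the name 'var_a':

import tensorflow as tf

with tf.Graph().as_default():
    # Hypothetical variable whose op name ('var_a_1') does not match
    # the name ('var_a') stored in the checkpoint.
    var_a_1 = tf.Variable(0.0, name='var_a_1')

    # Tell the Saver explicitly which checkpoint name maps to which variable.
    saver = tf.train.Saver({'var_a': var_a_1})

    with tf.Session() as sess:
        saver.restore(sess, 'Variables/map.ckpt')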

The solution is to create a new graph each time you call train(). The simplest way to fix it is to change your main program to create a new graph for each call to train() as follows:

# First train
with tf.Graph().as_default():
    train(False, 1)

# Following train
for i in xrange(10):
    with tf.Graph().as_default():
        train(True, 10)

...or, equivalently, you could move the with block inside the train() function.
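
A minimal sketch of that variant, assuming the same skipped network-definition code as in the question (mse, x, y_ and label come from that skipped part; the feed data is called train_data here to avoid shadowing the function name):

def train(load, step):
    # Build a fresh graph and session for every call, so each
    # tf.train.Saver only ever sees this call's variables.
    with tf.Graph().as_default(), tf.Session() as sess:
        """
        Defining the neural network is skipped here
        """
        train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)
        saver = tf.train.Saver()

        if not load:
            sess.run(tf.initialize_all_variables())
        else:
            saver.restore(sess, 'Variables/map.ckpt')

        for i in xrange(step):
            train_step.run(feed_dict = {x: train_data, y_: label})

        save_path = saver.save(sess, 'Variables/map.ckpt')
        print 'Model saved in file: ', save_path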
