TensorFlow checkpoint save and read
Question
I have a TensorFlow-based neural network and a set of variables.
The training function looks like this:
```python
def train(load=True, step=1):
    """
    Defining the neural network is skipped here
    """
    train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)

    # Saver
    saver = tf.train.Saver()

    if not load:
        # Initializing variables
        sess.run(tf.initialize_all_variables())
    else:
        saver.restore(sess, 'Variables/map.ckpt')
        print 'Model Restored!'

    # Perform stochastic gradient descent
    for i in xrange(step):
        train_step.run(feed_dict={x: train, y_: label})

    # Save model
    save_path = saver.save(sess, 'Variables/map.ckpt')
    print 'Model saved in file: ', save_path
    print 'Training Done!'
```
I was calling the training function like this:
```python
# First train
train(False, 1)

# Following trains
for i in xrange(10):
    train(True, 10)
```
I do the training this way because I need to feed different datasets into my model. However, when I call the train function like this, TensorFlow raises an error indicating that it cannot read the saved model from the file.
After some experiments, I concluded that this happened because checkpoint saving was slow: the next call to train would start reading before the file had been fully written to disk, producing the error.
I tried using time.sleep() to add a delay between calls, but it didn't help.
Does anyone know how to work around this kind of write/read error? Thank you very much!
Answer
There is a subtle issue in your code: each time you call the `train()` function, more nodes are added to the same TensorFlow graph, for all of the model variables and the rest of the neural network. This means that each time you construct a `tf.train.Saver()`, it includes all of the variables from the previous calls to `train()`. Each time you recreate the model, the variables are created with an extra `_N` suffix to give them unique names:
- Saver constructed with variables `var_a`, `var_b`.
- Saver constructed with variables `var_a`, `var_b`, `var_a_1`, `var_b_1`.
- Saver constructed with variables `var_a`, `var_b`, `var_a_1`, `var_b_1`, `var_a_2`, `var_b_2`.
- etc.
The default behavior of a `tf.train.Saver` is to associate each variable with the name of the corresponding op. This means that `var_a_1` won't be initialized from `var_a`, because they end up with different names.
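The renaming can be sketched without TensorFlow at all. The toy `Graph` class below is a hypothetical stand-in for `tf.Graph`'s name uniquification (it is not real TensorFlow code): a duplicate name gets an `_N` suffix, so a saver built after the second model construction would see both generations of variables.

```python
class Graph(object):
    """Toy stand-in for tf.Graph's name uniquification (illustration only)."""

    def __init__(self):
        self._counts = {}    # base name -> how many times it has been used
        self.variables = []  # every variable name created in this graph

    def variable(self, name):
        n = self._counts.get(name, 0)
        self._counts[name] = n + 1
        # The first use keeps the bare name; later uses get an _N suffix.
        unique = name if n == 0 else '%s_%d' % (name, n)
        self.variables.append(unique)
        return unique

def build_model(graph):
    graph.variable('var_a')
    graph.variable('var_b')

g = Graph()
build_model(g)      # first call to train(): creates var_a, var_b
build_model(g)      # second call adds nodes to the SAME graph
print(g.variables)  # ['var_a', 'var_b', 'var_a_1', 'var_b_1']
```

A saver constructed after the second call would try to restore `var_a_1` and `var_b_1`, which were never saved under those names.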
The solution is to create a new graph each time you call `train()`. The simplest fix is to change your main program to create a new graph for each call to `train()`, as follows:
```python
# First train
with tf.Graph().as_default():
    train(False, 1)

# Following trains
for i in xrange(10):
    with tf.Graph().as_default():
        train(True, 10)
```
...or, equivalently, you could move the `with` block inside the `train()` function.
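Why this restructuring works can also be sketched without TensorFlow. In the hypothetical sketch below, the `Graph` class stands in for `tf.Graph` and its `_N` suffixing; giving each call its own fresh instance (the analogue of `with tf.Graph().as_default():` inside `train()`) means the variable names never collide:

```python
class Graph(object):
    """Toy stand-in for tf.Graph (illustration only, not real TensorFlow)."""

    def __init__(self):
        self._counts = {}
        self.variables = []

    def variable(self, name):
        n = self._counts.get(name, 0)
        self._counts[name] = n + 1
        self.variables.append(name if n == 0 else '%s_%d' % (name, n))
        return self.variables[-1]

def train_with_fresh_graph():
    # Analogue of wrapping the body of train() in
    # `with tf.Graph().as_default():` -- a brand-new graph per call,
    # so name uniquification starts from scratch every time.
    graph = Graph()
    graph.variable('var_a')
    graph.variable('var_b')
    return graph.variables

first = train_with_fresh_graph()
second = train_with_fresh_graph()
print(first, second)  # both are ['var_a', 'var_b'] -- no _N suffixes
```

With matching names on every call, the saver can restore each variable from the checkpoint written by the previous call.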