在 TensorFlow 中保存特定权重 [英] Save specific weights in TensorFlow
问题描述
在我的神经网络中,我创建了一些 tf.Variable
对象,如下所示:
In my neural network, I create some tf.Variable
objects as follows:
weights = {
'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}
如何在不保存其他变量的情况下在特定次数的迭代后保存 weights
和 biases
中的变量?
How would I save the variables in weights
and biases
after a specific number of iterations without saving other variables?
推荐答案
在 TensorFlow 中保存变量的标准方法是使用 tf.train.Saver
对象.默认情况下,它会保存问题中的所有变量(即 tf.all_variables()
),但是您可以通过将 var_list
可选参数传递给 tf.train 来选择性地保存变量.Saver
构造函数:
The standard way to save variables in TensorFlow is to use a tf.train.Saver
object. By default it saves all of the variables in your problem (i.e., the results of tf.all_variables()
), but you can save variables selectively by passing the var_list
optional argument to the tf.train.Saver
constructor:
weights = {
'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}
# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)
# ...
# You need a TensorFlow Session to save variables.
sess = tf.Session()
# ...
# ...then call the following methods as appropriate:
weights_saver.save(sess) # Save the current value of the weights.
biases_saver.save(sess) # Save the current value of the biases.
<小时>
请注意,如果您将字典传递给 tf.train.Saver
构造函数(例如 weights
和/或 biases
字典来自您的问题),TensorFlow 将使用字典键(例如 'wc1_0'
)作为它创建或使用的任何检查点文件中相应变量的名称.
Note that if you pass a dictionary to the tf.train.Saver
constructor (such as the weights
and/or biases
dictionaries from your question), TensorFlow will use the dictionary key (e.g. 'wc1_0'
) as the name for the corresponding variable in any checkpoint files it creates or consumes.
默认情况下,或者如果您将 tf.Variable
对象列表传递给构造函数,TensorFlow 将改用 tf.Variable.name
属性.
By default, or if you pass a list of tf.Variable
objects to the constructor, TensorFlow will use the tf.Variable.name
property instead.
传递字典使您能够在为每个变量提供不同 Variable.name
属性的模型之间共享检查点.仅当您想将创建的检查点用于其他模型时,此细节才重要.
Passing a dictionary gives you the ability to share checkpoints between models that give different Variable.name
properties to each variable.
This detail is only important if you want to use the created checkpoints with another model.
这篇关于在 TensorFlow 中保存特定权重的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!