在 TensorFlow 中保存特定权重 [英] Save specific weights in TensorFlow

查看:88
本文介绍了在 TensorFlow 中保存特定权重的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在我的神经网络中,我创建了一些 tf.Variable 对象,如下所示:

In my neural network, I create some tf.Variable objects as follows:

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

如何在不保存其他变量的情况下在特定次数的迭代后保存 weightsbiases 中的变量?

How would I save the variables in weights and biases after a specific number of iterations without saving other variables?

推荐答案

在 TensorFlow 中保存变量的标准方法是使用 tf.train.Saver 对象.默认情况下,它会保存问题中的所有变量(即 tf.all_variables()),但是您可以通过将 var_list 可选参数传递给 tf.train 来选择性地保存变量.Saver 构造函数:

The standard way to save variables in TensorFlow is to use a tf.train.Saver object. By default it saves all of the variables in your problem (i.e., the results of tf.all_variables()), but you can save variables selectively by passing the var_list optional argument to the tf.train.Saver constructor:

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)

# ...
# You need a TensorFlow Session to save variables.
sess = tf.Session()
# ...

# ...then call the following methods as appropriate:
weights_saver.save(sess)  # Save the current value of the weights.
biases_saver.save(sess)   # Save the current value of the biases.

<小时>

请注意,如果您将字典传递给 tf.train.Saver 构造函数(例如 weights 和/或 biases 字典来自您的问题),TensorFlow 将使用字典键(例如 'wc1_0')作为它创建或使用的任何检查点文件中相应变量的名称.


Note that if you pass a dictionary to the tf.train.Saver constructor (such as the weights and/or biases dictionaries from your question), TensorFlow will use the dictionary key (e.g. 'wc1_0') as the name for the corresponding variable in any checkpoint files it creates or consumes.

默认情况下,或者如果您将 tf.Variable 对象列表传递给构造函数,TensorFlow 将改用 tf.Variable.name 属性.

By default, or if you pass a list of tf.Variable objects to the constructor, TensorFlow will use the tf.Variable.name property instead.

传递字典使您能够在为每个变量提供不同 Variable.name 属性的模型之间共享检查点.仅当您想将创建的检查点用于其他模型时,此细节才重要.

Passing a dictionary gives you the ability to share checkpoints between models that give different Variable.name properties to each variable. This detail is only important if you want to use the created checkpoints with another model.

这篇关于在 TensorFlow 中保存特定权重的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆