Rename variable scope of saved model in TensorFlow
Question
Is it possible to rename the variable scope of a given model in TensorFlow?
For instance, I created a logistic regression model for MNIST digits, based on the tutorial:
import tensorflow as tf  # TensorFlow 1.x API

with tf.variable_scope('my-first-scope'):
    NUM_IMAGE_PIXELS = 784
    NUM_CLASS_BINS = 10
    x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])

    W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS, NUM_CLASS_BINS]))
    b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))

    y = tf.nn.softmax(tf.matmul(x, W) + b)
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    saver = tf.train.Saver([W, b])

... # some training happens

saver.save(sess, 'my-model')
Now I want to reload the saved model in the 'my-first-scope' variable scope and then save everything again to a new file, under a new variable scope named 'my-second-scope'.
Recommended answer
You can use tf.contrib.framework.list_variables and tf.contrib.framework.load_variable as follows to achieve your goal:
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib is required)

# Step 1: build the original model under 'my-first-scope' and save it.
with tf.Graph().as_default(), tf.Session().as_default() as sess:
    with tf.variable_scope('my-first-scope'):
        NUM_IMAGE_PIXELS = 784
        NUM_CLASS_BINS = 10
        x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
        y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])

        W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS, NUM_CLASS_BINS]))
        b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))

        y = tf.nn.softmax(tf.matmul(x, W) + b)
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
        saver = tf.train.Saver([W, b])
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')

# Step 2: list the (name, shape) pairs stored in the checkpoint in the current directory.
# Named ckpt_vars rather than vars to avoid shadowing the Python builtin.
ckpt_vars = tf.contrib.framework.list_variables('.')

# Step 3: in a fresh graph, recreate each variable under the new scope name and save again.
with tf.Graph().as_default(), tf.Session().as_default() as sess:
    new_vars = []
    for name, shape in ckpt_vars:
        v = tf.contrib.framework.load_variable('.', name)
        new_vars.append(tf.Variable(v, name=name.replace('my-first-scope', 'my-second-scope')))

    saver = tf.train.Saver(new_vars)
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-new-model')
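One caveat about the renaming step: name.replace substitutes the old scope wherever it occurs in the variable name, not only at the start. If that matters for your checkpoint, a prefix-only rename may be safer. The helper below is a minimal sketch in plain Python. The function name rename_scope is hypothetical, not part of TensorFlow, and the example names assume the checkpoint stores the variables as 'my-first-scope/Variable' and 'my-first-scope/Variable_1', which you should confirm against the output of list_variables.

```python
def rename_scope(name, old_scope, new_scope):
    """Swap a leading `old_scope` path component of `name` for `new_scope`.

    Hypothetical helper for illustration; it only manipulates strings.
    """
    if name == old_scope:
        return new_scope
    prefix = old_scope + "/"
    if name.startswith(prefix):
        return new_scope + "/" + name[len(prefix):]
    # The name is not under old_scope, so leave it unchanged.
    return name


# Assumed checkpoint names; the last one only contains the old scope as a substring.
names = [
    "my-first-scope/Variable",
    "my-first-scope/Variable_1",
    "other/my-first-scope-copy",
]
renamed = [rename_scope(n, "my-first-scope", "my-second-scope") for n in names]
```

To use it in the answer's loop, pass rename_scope(name, 'my-first-scope', 'my-second-scope') as the name argument of tf.Variable in place of the name.replace call.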