在 TensorFlow 中重命名已保存模型的变量范围 [英] Rename variable scope of saved model in TensorFlow

查看:29
本文介绍了在 TensorFlow 中重命名已保存模型的变量范围的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

是否可以在 tensorflow 中重命名给定模型的变量范围?

Is it possible to rename the variable scope of a given model in tensorflow?

例如,我根据教程为 MNIST 数字创建了逻辑回归模型:

For instance, I created a logistic regression model for MNIST digits, based on the tutorial:

with tf.variable_scope('my-first-scope'):
    NUM_IMAGE_PIXELS = 784
    NUM_CLASS_BINS = 10
    x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])

    W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS,NUM_CLASS_BINS]))
    b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))

    y = tf.nn.softmax(tf.matmul(x,W) + b)
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    saver = tf.train.Saver([W, b])

... # some training happens

saver.save(sess, 'my-model')

现在我想在 'my-first-scope' 变量范围内重新加载保存的模型,然后再次将所有内容保存到一个新文件中,并在 'my 的新变量范围内-second-scope'.

Now I want to reload the saved model in the 'my-first-scope' variable scope and then save everything again to a new file and under a new variable scope of 'my-second-scope'.

推荐答案

您可以使用 tf.contrib.framework.list_variablestf.contrib.framework.load_variable 如下实现你的目标:

You can use tf.contrib.framework.list_variables and tf.contrib.framework.load_variable as follows to achieve your goal :

with tf.Graph().as_default(), tf.Session().as_default() as sess:
  with tf.variable_scope('my-first-scope'):
    NUM_IMAGE_PIXELS = 784
    NUM_CLASS_BINS = 10
    x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])

    W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS,NUM_CLASS_BINS]))
    b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))

    y = tf.nn.softmax(tf.matmul(x,W) + b)
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    saver = tf.train.Saver([W, b])
  sess.run(tf.global_variables_initializer())
  saver.save(sess, 'my-model')

vars = tf.contrib.framework.list_variables('.')
with tf.Graph().as_default(), tf.Session().as_default() as sess:

  new_vars = []
  for name, shape in vars:
    v = tf.contrib.framework.load_variable('.', name)
    new_vars.append(tf.Variable(v, name=name.replace('my-first-scope', 'my-second-scope')))

  saver = tf.train.Saver(new_vars)
  sess.run(tf.global_variables_initializer())
  saver.save(sess, 'my-new-model')

这篇关于在 TensorFlow 中重命名已保存模型的变量范围的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆