Training some variables in a restored TensorFlow model

Problem description

I have a TensorFlow (TF) model that I'd like to restore and retrain some of its parameters. I know that tf.get_operation_by_name('name of the optimizer') retrieves the original optimizer that was used to train the model before it was stored. However, I don't know how to pass the new list of TF variables that I want the optimizer to retrain!
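
(It may be worth noting, as my own reading of the API rather than anything confirmed here, that tf.get_operation_by_name returns a tf.Operation, i.e. the training op already baked into the graph, rather than the tf.train.Optimizer object itself, so it exposes no minimize() or var_list to rebind:)

Opt = tf.get_operation_by_name('Adam')  # the recorded training op
print(type(Opt))  # tensorflow.python.framework.ops.Operation, not an Optimizer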

This example helps illustrate what I want to do:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  # MNIST loader used below

learning_rate = 0.0001
training_iters = 60000
batch_size = 64
display_step = 20
ImVecDim = 784  # The number of elements in an image vector (flattening a 28x28 2D image)
NumOfClasses = 10
dropout = 0.8

with tf.Session() as sess:
   LoadMod = tf.train.import_meta_graph('simple_mnist.ckpt.meta')  # This object loads the model
   LoadMod.restore(sess, tf.train.latest_checkpoint('./')) # Loading weights and biases and other stuff to the model
   g = tf.get_default_graph()

   # Variables to be retrained:
   wc2 = g.get_tensor_by_name('wc2:0')
   bc2 = g.get_tensor_by_name('bc2:0')
   wc3 = g.get_tensor_by_name('wc3:0')
   bc3 = g.get_tensor_by_name('bc3:0')
   wd1 = g.get_tensor_by_name('wd1:0')
   bd1 = g.get_tensor_by_name('bd1:0')
   wd2 = g.get_tensor_by_name('wd2:0')
   bd2 = g.get_tensor_by_name('bd2:0')
   out_w = g.get_tensor_by_name('out_w:0')
   out_b = g.get_tensor_by_name('out_b:0')
   VarToTrain = [wc2,wc3,wd1,wd2,out_w,bc2,bc3,bd1,bd2,out_b]

   # Retrieving the optimizer:
   Opt = tf.get_operation_by_name('Adam')

   # Retraining:
   X = g.get_tensor_by_name('ImageIn:0')
   Y = g.get_tensor_by_name('LabelIn:0')
   KP = g.get_tensor_by_name('KeepProb:0')
   accuracy = g.get_tensor_by_name('NetAccuracy:0')
   cost = g.get_tensor_by_name('loss:0')
   step = 1
   while step * batch_size < training_iters:
       batch_xs, batch_ys = mnist.train.next_batch(batch_size)
       #####################################################################
       #    Here I want to pass (VarToTrain) to the optimizer (Opt)!       #
       #####################################################################
       if step % display_step == 0:
           acc = sess.run(accuracy, feed_dict={X: batch_xs, Y: batch_ys, KP: 1.})
           loss = sess.run(cost, feed_dict={X: batch_xs, Y: batch_ys, KP: 1.})
           print("Iter " + str(step * batch_size) + ", Minibatch Loss= " + "{:.6f}".format(
               loss) + ", Training Accuracy= " + "{:.5f}".format(acc))
       step += 1
   feed_dict = {X: mnist.test.images[:256], Y: mnist.test.labels[:256], KP: 1.0}
   ModelAccuracy = sess.run(accuracy, feed_dict)
   print('Retraining finished' + ', Test Accuracy = %f' % ModelAccuracy)


Recommended answer

Well, I have not figured out a way to do exactly what I want, but I've found a way around the problem: instead of passing a new list of variables to the original optimizer, I defined a new optimizer and passed those variables to its minimize() method. The code is given below:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  # MNIST loader used below

learning_rate = 0.0001
training_iters = 60000
batch_size = 64
display_step = 20
ImVecDim = 784  # The number of elements in an image vector (flattening a 28x28 2D image)
NumOfClasses = 10
dropout = 0.8

with tf.Session() as sess:
   LoadMod = tf.train.import_meta_graph('simple_mnist.ckpt.meta')  # This object loads the model
   LoadMod.restore(sess, tf.train.latest_checkpoint('./')) # Loading weights and biases and other stuff to the model
   g = tf.get_default_graph()
   # Retraining:
   X = g.get_tensor_by_name('ImageIn:0')
   Y = g.get_tensor_by_name('LabelIn:0')
   KP = g.get_tensor_by_name('KeepProb:0')
   accuracy = g.get_tensor_by_name('NetAccuracy:0')
   cost = g.get_tensor_by_name('loss:0')

   ############# Producing a list and defining a new optimizer #############
   VarToTrain = g.get_collection_ref('trainable_variables')
   del VarToTrain[0]  # Removing a variable from the list
   del VarToTrain[5]  # Removing another variable from the list
   optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).\
                 minimize(cost, var_list=VarToTrain)
   ##########################################################################
   step = 1
   while step * batch_size < training_iters:
       batch_xs, batch_ys = mnist.train.next_batch(batch_size)
       sess.run(optimizer, feed_dict={X: batch_xs, Y: batch_ys, KP: dropout})
       if step % display_step == 0:
           acc = sess.run(accuracy, feed_dict={X: batch_xs, Y: batch_ys, KP: 1.})
           loss = sess.run(cost, feed_dict={X: batch_xs, Y: batch_ys, KP: 1.})
           print("Iter " + str(step * batch_size) + ", Minibatch Loss= " + "{:.6f}".format(
               loss) + ", Training Accuracy= " + "{:.5f}".format(acc))
       step += 1
   feed_dict = {X: mnist.test.images[:256], Y: mnist.test.labels[:256], KP: 1.0}
   ModelAccuracy = sess.run(accuracy, feed_dict)
   print('Retraining finished' + ', Test Accuracy = %f' % ModelAccuracy)

The code above did the job, but it has some issues. First, for some reason, I keep getting error messages every time I define an optimizer similar to the original one, tf.train.AdamOptimizer(); the only optimizer I can define without TF throwing error messages is tf.train.GradientDescentOptimizer(). The other issue with this solution is its inconvenience: to produce the list of variables I want to train, I first have to produce a list of all trainable variables using VarToTrain = g.get_collection_ref('trainable_variables'), print them out, memorize the positions of the unwanted variables in the list, and then delete them one by one with del. There must be a more elegant way to do that; what I have done works fine only for small networks.
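
A likely explanation for the Adam errors (an assumption based on how TF 1.x optimizers work, not verified against this exact graph): a freshly constructed tf.train.AdamOptimizer creates new slot variables (m, v, and the beta power counters) that do not exist in the checkpoint, so running its train op hits uninitialized variables, while plain gradient descent creates no slots, which would explain why it works. Below is a minimal sketch that addresses both complaints: it selects the variables to retrain by name instead of positional del (the frozen names 'wc1' and 'bc1' are made up for illustration), then initializes only the variables the new optimizer just created:

# Select trainable variables by name; 'wc1'/'bc1' are hypothetical layers to freeze.
freeze = {'wc1', 'bc1'}
VarToTrain = [v for v in tf.trainable_variables()
              if v.name.split(':')[0] not in freeze]

# Build the Adam train op; this creates new (uninitialized) slot variables.
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).\
              minimize(cost, var_list=VarToTrain)

# Initialize only what is still uninitialized, leaving restored weights intact.
uninit_names = set(sess.run(tf.report_uninitialized_variables()))  # names as bytes
new_vars = [v for v in tf.global_variables()
            if v.name.split(':')[0].encode() in uninit_names]
sess.run(tf.variables_initializer(new_vars))

With that, the positional del bookkeeping and the fallback to GradientDescentOptimizer should both become unnecessary, assuming the variable names above match the actual graph.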
