在TensorFlow中修改已恢复的CNN模型的权重和偏差 [英] Modifying the weights and biases of a restored CNN model in TensorFlow

查看:565
本文介绍了在TensorFlow中修改已恢复的CNN模型的权重和偏差的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我最近开始使用TensorFlow(TF),遇到一个需要帮助的问题.基本上,我已经恢复了预训练的模型,在重新测试其准确性之前,需要修改其一层的权重和偏差.现在,我的问题如下: 如何使用TF中的assign方法更改权重和偏差?甚至可以在TF中修改已还原模型的权重吗?

I have recently started using TensorFlow (TF), and I have come across a problem that I need some help with. Basically, I've restored a pre-trained model, and I need to modify the weights and biases of one of its layers before I retest its accuracy. Now, my problem is the following: how can I change the weights and biases using the assign method in TF? Is modifying the weights of a restored modeled even possible in TF?

这是我的代码:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # Imports the MINST dataset

# Data Set:
# ---------
mnist = input_data.read_data_sets("/home/frr/MNIST_data", one_hot=True)# An object where data is stored

ImVecDim = 784# The number of elements in a an image vector (flattening a 28x28 2D image)
NumOfClasses = 10

g = tf.get_default_graph()

with tf.Session() as sess:
  LoadMod = tf.train.import_meta_graph('simple_mnist.ckpt.meta')  # This object loads the model
  LoadMod.restore(sess, tf.train.latest_checkpoint('./'))# Loading weights and biases and other stuff to the model

  # ( Here I'd like to modify the weights and biases of layer 1, set them to one for example, before I go ahead and test the accuracy ) #

  # Testing the acuracy of the model:
  X = g.get_tensor_by_name('ImageIn:0')
  Y = g.get_tensor_by_name('LabelIn:0')
  KP = g.get_tensor_by_name('KeepProb:0')
  Accuracy = g.get_tensor_by_name('NetAccuracy:0')
  feed_dict = { X: mnist.test.images[:256], Y: mnist.test.labels[:256], KP: 1.0 }
  print( 'Model Accuracy = ' )
  print( sess.run( Accuracy, feed_dict ) )

推荐答案

除了现有答案外,还可以通过

In addition to an existing answer, tensor update can be performed via tf.assign function.

v1 = sess.graph.get_tensor_by_name('v1:0')
print(sess.run(v1))   # 1.0
sess.run(tf.assign(v1, v1 + 1))
print(sess.run(v1))   # 2.0

这篇关于在TensorFlow中修改已恢复的CNN模型的权重和偏差的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆