从Tensorflow模型获取权重 [英] Get weights from tensorflow model

查看:114
本文介绍了从Tensorflow模型获取权重的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

你好,我想从张量流微调VGG模型.我有两个问题.

Hello I would like to finetune VGG model from tensorflow. I have two questions.

如何从网络获取权重? trainable_variables为我返回空列表.

How to get the weights from network? The trainable_variables returns empty list for me.

我从这里使用现有模型: https://github.com/ry/tensorflow-vgg16 . 我找到有关获取权重的帖子,但是由于import_graph_def,这对我不起作用. 在TensorFlow训练的模型

I used existing model from here: https://github.com/ry/tensorflow-vgg16 . I find the post about getting weights however this doesn't work for me because of import_graph_def. Get the value of some weights in a model trained by TensorFlow

import tensorflow as tf
import PIL.Image
import numpy as np

with open("../vgg16.tfmodel", mode='rb') as f:
  fileContent = f.read()

graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

images = tf.placeholder("float", [None, 224, 224, 3])

tf.import_graph_def(graph_def, input_map={ "images": images })
print("graph loaded from disk")

graph = tf.get_default_graph()

cat = np.asarray(PIL.Image.open('../cat224.jpg'))
print(cat.shape)
init = tf.initialize_all_variables()

with tf.Session(graph=graph) as sess:
  print(tf.trainable_variables() )
  sess.run(init)

推荐答案

预训练的VGG-16模型将所有模型参数编码为 tf.constant() . (例如,请参见tf.constant() 此处.)结果,模型参数将不会出现在tf.trainable_variables()中,并且在没有进行大量外科手术的情况下该模型也不可变:您需要用

This pretrained VGG-16 model encodes all of the model parameters as tf.constant() ops. (See, for example, the calls to tf.constant() here.) As a result, the model parameters would not appear in tf.trainable_variables(), and the model is not mutable without substantial surgery: you would need to replace the constant nodes with tf.Variable objects that start with the same value in order to continue training.

通常,在导入图进行再训练时, tf.train.import_meta_graph() 函数,因为该函数会加载其他元数据(包括变量的集合). tf.import_graph_def() 函数是较低级别的,并且不会填充这些集合.

In general, when importing a graph for retraining, the tf.train.import_meta_graph() function should be used, as this function loads additional metadata (including the collections of variables). The tf.import_graph_def() function is lower level, and does not populate these collections.

这篇关于从Tensorflow模型获取权重的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆