如何从 Tensorflow 中的 .pb 模型获取权重 [英] How to get weights from .pb model in Tensorflow

查看:61
本文介绍了如何从 Tensorflow 中的 .pb 模型获取权重的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我训练了一个模型,然后通过冻结该模型创建了一个 .pb 文件.所以,我的问题是如何从 .pb 文件中获取权重,或者我必须做更多的过程来获取权重

I trained one model and then create one .pb file by freeze that model. so, my question is how to get weights from .pb file or i have to do more process for get weights

@mrry,请指导我.

@mrry, please guide me.

推荐答案

让我们先从 .pb 文件中加载图表.

Let us first load the graph from .pb file.

import tensorflow as tf
from tensorflow.python.platform import gfile

GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' #path to your .pb file
with tf.Session(config=config) as sess:
  print("load graph")
  with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
    graph_nodes=[n for n in graph_def.node]

现在,当您将图形冻结到 .pb 文件时,您的变量将转换为 Const 类型,并且作为训练变量的权重也将存储为 Const.pb 文件中.graph_nodes 包含图中的所有节点.但是我们对所有 Const 类型的节点感兴趣.

Now when you freeze a graph to .pb file your variables are converted to Const type and the weights which were trainabe variables would also be stored as Const in .pb file. graph_nodes contains all the nodes in graph. But we are interested in all the Const type nodes.

wts = [n for n in graph_nodes if n.op=='Const']

wts 的每个元素都是 NodeDef 类型.它有几个属性,如名称、操作等.可以按如下方式提取值 -

Each element of wts is of NodeDef type. It has several atributes such as name, op etc. The values can be extracted as follows -

from tensorflow.python.framework import tensor_util

for n in wts:
    print "Name of the node - %s" % n.name
    print "Value - " 
    print tensor_util.MakeNdarray(n.attr['value'].tensor)

希望这能解决您的疑虑.

Hope this solves your concern.

这篇关于如何从 Tensorflow 中的 .pb 模型获取权重的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆