如何从python中的.pb文件恢复Tensorflow模型? [英] How to restore Tensorflow model from .pb file in python?
问题描述
我有一个tensorflow .pb文件,我想将其加载到python DNN中,恢复图形并获得预测.我这样做是为了测试创建的.pb文件是否可以做出与普通Saver.save()模型相似的预测.
I have an tensorflow .pb file which I would like to load into python DNN, restore the graph and get the predictions. I am doing this to test out whether the .pb file created can make the predictions similar to the normal Saver.save() model.
我的基本问题是,当我使用上述.pb文件在Android上进行预测时,得到的预测值有很大不同
My basic problem is am getting a very different value of predictions when I make them on Android using the above mentioned .pb file
我的.pb文件创建代码:
My .pb file creation code:
frozen_graph = tf.graph_util.convert_variables_to_constants(
session,
session.graph_def,
['outputLayer/Softmax']
)
with open('frozen_model.pb', 'wb') as f:
f.write(frozen_graph.SerializeToString())
所以我有两个主要问题:
So I have two major concerns:
- 如何将上述.pb文件加载到python Tensorflow模型中?
- 为什么我在python和android中获得的预测值完全不同?
推荐答案
以下代码将读取模型并打印出图中节点的名称.
The following code will read the model and print out the names of the nodes in the graph.
import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
print("load graph")
with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
graph_nodes=[n for n in graph_def.node]
names = []
for t in graph_nodes:
names.append(t.name)
print(names)
您正确冻结了图形,这就是为什么您得到不同结果的原因,基本上权重没有存储在模型中.您可以使用 freeze_graph.py (
You are freezing the graph properly that is why you are getting different results basically weights are not getting stored in your model. You can use the freeze_graph.py (link) for getting a correctly stored graph.
这篇关于如何从python中的.pb文件恢复Tensorflow模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!