给定张量流模型图,如何查找输入节点和输出节点名称 [英] Given a tensor flow model graph, how to find the input node and output node names

查看:350
本文介绍了给定张量流模型图,如何查找输入节点和输出节点名称的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在Tensor flow Camera演示中使用自定义模型进行分类. 我生成了一个.pb文件(序列化的protobuf文件),并且可以显示其中包含的巨大图形. 如[ https://www中所述,将此图转换为优化图. oreilly.com/learning/tensorflow-on-android] ,可以使用以下过程:

I use custom model for classification in Tensor flow Camera Demo. I generated a .pb file (serialized protobuf file) and I could display the huge graph it contains. To convert this graph to a optimized graph, as given in [https://www.oreilly.com/learning/tensorflow-on-android], the following procedure could be used:

$ bazel-bin/tensorflow/python/tools/optimize_for_inference  \
--input=tf_files/retrained_graph.pb \
--output=tensorflow/examples/android/assets/retrained_graph.pb
--input_names=Mul \
--output_names=final_result

此处介绍如何从图形显示中查找输入名称和输出名称. 不使用专有名称时,设备会崩溃:

Here how to find the input_names and output_names from the graph display. When I dont use proper names, I get device crash:

E/TensorFlowInferenceInterface(16821): Failed to run TensorFlow inference 
with inputs:[AvgPool], outputs:[predictions]

E/AndroidRuntime(16821): FATAL EXCEPTION: inference

E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible 
shapes: [1,224,224,3] vs. [32,1,1,2048]

E/AndroidRuntime(16821):     [[Node: dropout/dropout/mul = Mul[T=DT_FLOAT, 
_device="/job:localhost/replica:0/task:0/cpu:0"](dropout/dropout/div, 
dropout/dropout/Floor)]]

推荐答案

尝试一下:

运行python

>>> import tensorflow as tf
>>> gf = tf.GraphDef()
>>> gf.ParseFromString(open('/your/path/to/graphname.pb','rb').read())

然后

>>> [n.name + '=>' +  n.op for n in gf.node if n.op in ( 'Softmax','Placeholder')]

然后,您可以获得类似于以下内容的结果:

Then, you can get result similar to this:

['Mul=>Placeholder', 'final_result=>Softmax']

但是我不确定这是关于错误消息的节点名称的问题. 我猜您在加载图形文件时提供了错误的论据,或者您生成的图形文件有问题吗?

But I'm not sure it's the problem of node names regarding the error messages. I guess you provided wrong arguements when loading the graph file or your generated graph file is something wrong?

检查此部分:

E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible 
shapes: [1,224,224,3] vs. [32,1,1,2048]

更新: 对不起, 如果您使用的是(重新)训练图,请尝试以下方法:

UPDATE: Sorry, if you're using (re)trained graph , then try this:

[n.name + '=>' +  n.op for n in gf.node if n.op in ( 'Softmax','Mul')]

看来,(重新)训练图将输入/输出操作名称保存为"Mul"和"Softmax",而优化和/或量化图将它们保存为占位符"和"Softmax".

It seems that (re)trained graph saves input/output op name as "Mul" and "Softmax", while optimized and/or quantized graph saves them as "Placeholder" and "Softmax".

BTW 在移动环境中使用经过重新训练的图形:

BTW, using retrained graph in mobile environment is not recommended according to Peter Warden's post: https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/ . It's better to use quantized or memmapped graph due to performance and file size issue, I couldn't find out how to load memmapped graph in android though...:( (no problem loading optimized / quantized graph in android)

这篇关于给定张量流模型图,如何查找输入节点和输出节点名称的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆