How to run inference using a TensorFlow 2.2 pb file?
Question
I followed this tutorial: https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/ However, I still don't know how to run inference with frozen_func (see my code below). How do I run inference using a pb file in TensorFlow 2.2? Thanks.
import tensorflow as tf

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    print("-" * 50)
    print("Frozen model layers: ")
    layers = [op.name for op in import_graph.get_operations()]
    if print_graph:
        for layer in layers:
            print(layer)
    print("-" * 50)

    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

# Load frozen graph using TensorFlow 1.x functions
with tf.io.gfile.GFile("/content/drive/My Drive/Model_file/froze_graph.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    loaded = graph_def.ParseFromString(f.read())

# Wrap frozen graph to ConcreteFunctions
frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                inputs=["wav_data:0"],
                                outputs=["labels_softmax:0"],
                                print_graph=True)
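For reference, the ConcreteFunction returned by wrap_frozen_graph can itself be called on a tensor to run inference, which is what the linked tutorial does. The sketch below is self-contained under a few assumptions: it substitutes a hypothetical squaring graph (input "x:0", output "y:0") for froze_graph.pb, reuses the question's wrapping logic without the printing, and calls frozen_func positionally. Because outputs is passed as a list, the call returns a list, hence the [0]. For the question's own graph, the fed tensor must match the dtype and shape of the wav_data placeholder, which we have not checked.

```python
import tensorflow as tf


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Import a GraphDef and prune it to a ConcreteFunction (the question's approach)."""
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


# Hypothetical stand-in for froze_graph.pb: a graph that squares its input.
@tf.function
def square(x):
    return tf.square(x, name="y")


graph_def = square.get_concrete_function(
    tf.TensorSpec(None, tf.float32)).graph.as_graph_def()

frozen_func = wrap_frozen_graph(graph_def, inputs=["x:0"], outputs=["y:0"])

# Inference is simply a call on the pruned ConcreteFunction; the input tensor's
# dtype must match the placeholder. outputs was a list, so the result is a list.
probs = frozen_func(tf.constant([1.0, 2.0, 3.0]))[0]
print(probs.numpy())
```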
Answer
You can use tf.graph_util.import_graph_def inside a tf.function to do that. For example, suppose you make a test GraphDef file my_func.pb like this:
import tensorflow as tf
# Test function to make into a GraphDef file
@tf.function
def my_func(x):
    return tf.square(x, name='y')
# Get graph
g = my_func.get_concrete_function(tf.TensorSpec(None, tf.float32)).graph
# Write to file
tf.io.write_graph(g, '.', 'my_func.pb', as_text=False)
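The tensor names used when loading ("x:0", "y:0") are node names in the written GraphDef plus an output index. As a sketch (an addition, not part of the original answer), you can read the file back and list its nodes to find those names; the same listing is one way to confirm names such as wav_data and labels_softmax in the question's froze_graph.pb. A temporary directory replaces '.' here so the example runs anywhere.

```python
import os
import tempfile

import tensorflow as tf


@tf.function
def my_func(x):
    return tf.square(x, name="y")


# Write the GraphDef as in the answer, but into a temporary directory.
out_dir = tempfile.mkdtemp()
g = my_func.get_concrete_function(tf.TensorSpec(None, tf.float32)).graph
tf.io.write_graph(g, out_dir, "my_func.pb", as_text=False)
path = os.path.join(out_dir, "my_func.pb")

# Read the binary GraphDef back.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(path, "rb") as f:
    graph_def.ParseFromString(f.read())

# Map node name -> op type; inputs show up as "Placeholder" nodes.
node_ops = {node.name: node.op for node in graph_def.node}
for name, op in node_ops.items():
    print(f"{name}: {op}")
```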
You can then load it and use it like this:
import tensorflow as tf
from tensorflow.core.framework.graph_pb2 import GraphDef
# Load GraphDef
with open('my_func.pb', 'rb') as f:
    gd = GraphDef()
    gd.ParseFromString(f.read())

@tf.function
def my_func2(x):
    # Ensure the input is a tensor of the right type
    x = tf.convert_to_tensor(x, tf.float32)
    # Import the graph giving x as input and getting the output y
    y = tf.graph_util.import_graph_def(
        gd, input_map={'x:0': x}, return_elements=['y:0'])[0]
    return y
tf.print(my_func2(2))
# 4
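The same pattern can be factored into a small reusable helper. The sketch below is a hypothetical generalization of my_func2 (the name make_inference_fn and its parameters are our own, not a TensorFlow API): it takes the input and output tensor names and the input dtype, and returns a tf.function that imports the graph on each trace. To stay self-contained it builds the squaring GraphDef in memory instead of reading my_func.pb. For the question's file, the call would presumably be make_inference_fn(graph_def, "wav_data:0", "labels_softmax:0", dtype=...), with the dtype read from the wav_data placeholder; that mapping is an assumption we have not verified against froze_graph.pb.

```python
import tensorflow as tf


def make_inference_fn(graph_def, input_name, output_name, dtype=tf.float32):
    """Hypothetical helper generalizing my_func2 from the answer.

    input_name / output_name are tensor names such as "x:0"; dtype must
    match the dtype of the input placeholder stored in graph_def.
    """
    @tf.function
    def infer(x):
        # Convert the input to the dtype the imported placeholder expects.
        x = tf.convert_to_tensor(x, dtype)
        # Import the graph, feeding x in place of the placeholder, and
        # return the requested output tensor.
        return tf.graph_util.import_graph_def(
            graph_def, input_map={input_name: x},
            return_elements=[output_name])[0]

    return infer


# Self-contained demo: the same squaring GraphDef as my_func.pb, built in memory.
@tf.function
def my_func(x):
    return tf.square(x, name="y")


gd = my_func.get_concrete_function(
    tf.TensorSpec(None, tf.float32)).graph.as_graph_def()

infer = make_inference_fn(gd, "x:0", "y:0")
out = infer(tf.constant([1.0, 2.0, 3.0]))
tf.print(out)
```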