如何从给定的模型获取Graph(或GraphDef)? [英] How to get Graph (or GraphDef) from a given Model?
问题描述
我有一个使用Tensorflow 2和Keras定义的大型模型. 该模型在Python中运行良好.现在,我想将其导入C ++项目.
I have a big model defined using Tensorflow 2 with Keras. The model works well in Python. Now, I want to import it into C++ project.
在我的C ++项目中,我使用TF_GraphImportGraphDef
函数.
如果我使用以下代码准备*.pb
文件,则效果很好:
Inside my C++ project, I use TF_GraphImportGraphDef
function.
It works well if I prepare *.pb
file using the following code:
with open('load_model.pb', 'wb') as f:
f.write(tf.compat.v1.get_default_graph().as_graph_def().SerializeToString())
我已经在使用Tensorflow 1(使用tf.compat.v1.*函数)编写的简单网络上尝试了此代码.效果很好.
I've tried this code on a simple network written using Tensorflow 1 (using tf.compat.v1.* functions). It works well.
现在我想将我的大模型(开头提到,使用Tensorflow 2编写)导出到C ++项目.为此,我需要从模型中获取一个Graph
或GraphDef
对象.问题是:该怎么做?我没有找到任何属性或函数来获取它.
Now I want to export my big model (mentioned at the beginning, written using Tensorflow 2) to the C++ project. To do this, I need to get a Graph
or GraphDef
object from my model. The question is: how to do this? I didn't find any property or function to get it.
我还尝试使用tf.saved_model.save(model, 'model')
保存整个模型.它会生成包含不同文件(包括saved_model.pb
文件)的目录.不幸的是,当我尝试使用TF_GraphImportGraphDef
函数在C ++中加载此文件时,程序将引发异常.
I've also tried to use tf.saved_model.save(model, 'model')
to save the whole model. It generates a directory with different files including saved_model.pb
file. Unfortunately, when I try to load this file in C++ using TF_GraphImportGraphDef
function, the program throws an exception.
推荐答案
tf.saved_model.save
不包含 GraphDef
消息,但 SavedModel
.您可以在Python中遍历SavedModel
以获得嵌入的图形,但是那将无法立即进行作为冻结的图形,因此正确处理可能会很困难.取而代之的是,C ++ API现在包括一个
The protocol buffers file generated by tf.saved_model.save
does not contain a GraphDef
message, but a SavedModel
. You could traverse that SavedModel
in Python to get the embedded graph(s) in it, but that would not immediately work as a frozen graph, so getting it right would probably be difficult. Instead of that, the C++ API now includes a LoadSavedModel
call that allows you to load a whole saved model from a directory. It should look some like this:
#include <iostream>
#include <...> // Add necessary TF include directives
using namespace std;
using namespace tensorflow;
int main()
{
// Path to saved model directory
const string export_dir = "...";
// Load model
Status s;
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;
s = LoadSavedModel(session_options, run_options, export_dir,
// default "serve" tag set by tf.saved_model.save
{"serve"}, &bundle));
if (!.ok())
{
cerr << "Could not load model: " << s.error_message() << endl;
return -1;
}
// Model is loaded
// ...
return 0;
}
From here, you could do different things. Maybe you would be most comfortable converting that saved model into a frozen graph, using FreezeSavedModel
, which should allow you to do things pretty much as you were doing them before:
GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
s = FreezeSavedModel(bundle, &frozen_graph_def,
&inputs, &outputs));
if (!s.ok())
{
cerr << "Could not freeze model: " << s.error_message() << endl;
return -1;
}
否则,您可以直接使用保存的模型对象:
Otherwise, you can work directly with the saved model object:
// Default "serving_default" signature name set by tf.saved_model_save
const SignatureDef& signature_def = bundle.GetSignatures().at("serving_default");
// Get input and output names (different from layer names)
// Key is input and output layer names
const string input_name = signature_def.inputs().at("my_input").name();
const string output_name = signature_def.inputs().at("my_output").name();
// Run model
Tensor input = ...;
std::vector<Tensor> outputs;
s = bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));
if (!s.ok())
{
cerr << "Error running model: " << s.error_message() << endl;
return -1;
}
// Get result
Tensor& output = outputs[0];
这篇关于如何从给定的模型获取Graph(或GraphDef)?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!