How to get Graph (or GraphDef) from a given Model?


Question

I have a big model defined using Tensorflow 2 with Keras. The model works well in Python. Now, I want to import it into a C++ project.

Inside my C++ project, I use the TF_GraphImportGraphDef function. It works well if I prepare the *.pb file using the following code:

    # Serialize the default graph's GraphDef to a binary .pb file
    with open('load_model.pb', 'wb') as f:
        f.write(tf.compat.v1.get_default_graph().as_graph_def().SerializeToString())

I've tried this code on a simple network written using Tensorflow 1 (using tf.compat.v1.* functions). It works well.

Now I want to export my big model (mentioned at the beginning, written using Tensorflow 2) to the C++ project. To do this, I need to get a Graph or GraphDef object from my model. The question is: how do I do that? I haven't found any property or function to get it.

I've also tried using tf.saved_model.save(model, 'model') to save the whole model. It generates a directory with several files, including a saved_model.pb file. Unfortunately, when I try to load this file in C++ with the TF_GraphImportGraphDef function, the program throws an exception.

Answer

The protocol buffers file generated by tf.saved_model.save does not contain a GraphDef message, but a SavedModel. You could traverse that SavedModel in Python to get the embedded graph(s), but that would not immediately work as a frozen graph, so getting it right would probably be difficult. Instead, the C++ API now includes a LoadSavedModel call that allows you to load a whole saved model from a directory. It should look something like this:

#include <iostream>
#include <...>  // Add necessary TF include directives

using namespace std;
using namespace tensorflow;

int main()
{
    // Path to saved model directory
    const string export_dir = "...";
    // Load model
    Status s;
    SavedModelBundle bundle;
    SessionOptions session_options;
    RunOptions run_options;
    s = LoadSavedModel(session_options, run_options, export_dir,
                       // default "serve" tag set by tf.saved_model.save
                       {"serve"}, &bundle);
    if (!s.ok())
    {
        cerr << "Could not load model: " << s.error_message() << endl;
        return -1;
    }
    // Model is loaded
    // ...
    return 0;
}
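
Incidentally, once the bundle is loaded you already have access to a GraphDef, which is what the original question asks for. This is a sketch based on the public members of SavedModelBundle, not part of the original answer; note that this graph is not frozen (the variables still live in the session):

// The bundle's MetaGraphDef embeds the model's (non-frozen) GraphDef
const GraphDef& graph_def = bundle.meta_graph_def.graph_def();
cout << "Graph has " << graph_def.node_size() << " nodes" << endl;

// List the signatures stored in the model (useful to verify signature names)
for (const auto& sig : bundle.GetSignatures())
    cout << "Signature: " << sig.first << endl;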

From here, you could do different things. Maybe you would be most comfortable converting that saved model into a frozen graph, using FreezeSavedModel, which should allow you to do things pretty much as you were doing them before:

GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
s = FreezeSavedModel(bundle, &frozen_graph_def,
                     &inputs, &outputs);
if (!s.ok())
{
    cerr << "Could not freeze model: " << s.error_message() << endl;
    return -1;
}
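
If you would rather keep your existing TF_GraphImportGraphDef loading path, frozen_graph_def is an ordinary protocol buffer message, so you can write it to a .pb file. A minimal sketch (the file name is just an example, and it needs an extra #include <fstream> at the top):

// Write the frozen graph to disk; the resulting .pb file can then be
// loaded with TF_GraphImportGraphDef just like the files generated before
std::ofstream out("frozen_model.pb", std::ios::binary);
if (!frozen_graph_def.SerializeToOstream(&out))
{
    cerr << "Could not write frozen graph" << endl;
    return -1;
}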

Otherwise, you can work directly with the saved model object:

// Default "serving_default" signature name set by tf.saved_model.save
const SignatureDef& signature_def = bundle.GetSignatures().at("serving_default");
// Get input and output tensor names (these differ from the layer names);
// the map keys are the signature's input and output names
const string input_name = signature_def.inputs().at("my_input").name();
const string output_name = signature_def.outputs().at("my_output").name();
// Run model
Tensor input = ...;
std::vector<Tensor> outputs;
s = bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs);
if (!s.ok())
{
    cerr << "Error running model: " << s.error_message() << endl;
    return -1;
}
// Get result
Tensor& output = outputs[0];
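
The construction of the input tensor is elided above. A minimal sketch, assuming a single float input of shape [1, 4] (adapt the shape and dtype to your model):

// Build an input tensor and fill it with data
Tensor input(DT_FLOAT, TensorShape({1, 4}));
auto in_flat = input.flat<float>();
for (int i = 0; i < in_flat.size(); ++i)
    in_flat(i) = 0.0f;  // replace with your actual input values

// After Run succeeds, read values back from the output tensor
auto out_flat = output.flat<float>();
cout << "First output value: " << out_flat(0) << endl;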
