Tensorflow在C ++中导出和运行图形的不同方法 [英] Tensorflow Different ways to Export and Run graph in C++

查看:96
本文介绍了Tensorflow在C ++中导出和运行图形的不同方法的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

要将训练有素的网络导入C ++,您需要导出网络才能这样做。经过大量搜索并发现几乎没有任何信息之后,我们澄清说我们应该使用 freeze_graph()就能做到。

For importing your trained network to the C++ you need to export your network to be able to do so. After searching a lot and finding almost no information about it, it was clarified that we should use freeze_graph() to be able to do it.

感谢Tensorflow的新0.7版本,他们添加了文档

Thanks to the new 0.7 version of Tensorflow, they added documentation of it.

查阅文档后,我发现几乎没有类似的方法,您能说出 freeze_graph()和:
tf.train.export_meta_graph ,因为它具有相似的参数,但它似乎也可以用于将模型导入C ++(我只是猜测区别在于,使用这种方法输出的文件只能使用 import_graph_def()或其他方式?)

After looking into documentations, I found that there are few similar methods, can you tell what is the difference between freeze_graph() and: tf.train.export_meta_graph as it has similar parameters, but it seems it can also be used for importing models to C++ (I just guess the difference is that for using the file output by this method you can only use import_graph_def() or it's something else?)

还有一个有关如何使用 write_graph()的问题:
在文档中给出了 graph_def sess.graph_def 提供,但在 freeze_graph()中的示例中,它为 sess.graph.as_graph_def ()。两者有什么区别?

Also one question about how to use write_graph(): In documentations the graph_def is given by sess.graph_def but in examples in freeze_graph() it is sess.graph.as_graph_def(). What is the difference between these two?

此问题与此问题有关。

谢谢!

推荐答案

这是我的解决方案利用TF 0.12中引入的V2检查点。

Here's my solution utilizing the V2 checkpoints introduced in TF 0.12.

无需将所有变量都转换为常量或冻结图表

There's no need to convert all variables to constants or freeze the graph.

为清楚起见,V2检查点在我的目录模型

Just for clarity, a V2 checkpoint looks like this in my directory models:

checkpoint  # some information on the name of the files in the checkpoint
my-model.data-00000-of-00001  # the saved weights
my-model.index  # probably definition of data layout in the previous file
my-model.meta  # protobuf of the graph (nodes and topology info)

Python部分(保存)

with tf.Session() as sess:
    tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')

如果使用 tf。创建 Saver $ c>,您可以节省一些头痛和存储空间。但是,也许某些更复杂的模型需要保存所有数据,然后将此参数删除到 Saver 中,只需确保您正在创建 Saver 之后被创建。为所有变量/层赋予唯一的名称也是非常明智的,否则您可以在不同的问题中运行。

If you create the Saver with tf.trainable_variables(), you can save yourself some headache and storage space. But maybe some more complicated models need all data to be saved, then remove this argument to Saver, just make sure you're creating the Saver after your graph is created. It is also very wise to give all variables/layers unique names, otherwise you can run in different problems.

Python部分(推断)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models/my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('models/'))
    outputTensors = sess.run(outputOps, feed_dict=feedDict)

C ++部分(推论)

请注意, checkpointPath 并非任何现有文件的路径,只是它们的通用前缀。如果您错误地将文件放置到 .index 文件的路径,则TF不会告诉您这是错误的,但是由于未初始化的变量,它会在推断过程中消失。

Note that checkpointPath isn't a path to any of the existing files, just their common prefix. If you mistakenly put there path to the .index file, TF won't tell you that was wrong, but it will die during inference due to uninitialized variables.

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>

using namespace std;
using namespace tensorflow;

...
// set up your input paths
const string pathToGraph = "models/my-model.meta"
const string checkpointPath = "models/my-model";
...

auto session = NewSession(SessionOptions());
if (session == nullptr) {
    throw runtime_error("Could not create Tensorflow session.");
}

Status status;

// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
    throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}

// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
    throw runtime_error("Error creating graph: " + status.ToString());
}

// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
        {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
        {},
        {graph_def.saver_def().restore_op_name()},
        nullptr);
if (!status.ok()) {
    throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}

// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);

这篇关于Tensorflow在C ++中导出和运行图形的不同方法的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆