是否有关于如何生成包含训练过的 TensorFlow 图的 protobuf 文件的示例 [英] Is there an example on how to generate protobuf files holding trained TensorFlow graphs

查看:21
本文介绍了是否有关于如何生成包含训练过的 TensorFlow 图的 protobuf 文件的示例的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在查看关于如何部署的 Google 示例并在 Android 上使用预训练的 Tensorflow 图(模型).此示例使用位于以下位置的 .pb 文件:

I am looking at Google's example on how to deploy and use a pre-trained Tensorflow graph (model) on Android. This example uses a .pb file at:

https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

这是一个自动下载文件的链接.

示例展示了如何将 .pb 文件加载到 Tensorflow 会话并使用它进行分类,但似乎没有提到如何生成这样的 .pb 文件,在训练图之后(例如,在 Python 中).

The example shows how to load the .pb file to a Tensorflow session and use it to perform classification, but it doesn't seem to mention how to generate such a .pb file, after a graph is trained (e.g., in Python).

有没有关于如何做到这一点的例子?

Are there any examples on how to do that?

推荐答案

freeze_graph.py 脚本,它是 TensorFlow 存储库的一部分,现在用作生成协议缓冲区的工具,表示冻结"训练模型,来自现有的 TensorFlow GraphDef 和保存的检查点.它使用与下面描述的相同的步骤,但使用起来要容易得多.

The freeze_graph.py script, which is part of the TensorFlow repository, now serves as a tool that generates a protocol buffer representing a "frozen" trained model, from an existing TensorFlow GraphDef and a saved checkpoint. It uses the same steps as described below, but it much easier to use.

目前该过程没有很好的文档记录(并且有待改进),但大致步骤如下:

Currently the process isn't very well documented (and subject to refinement), but the approximate steps are as follows:

  1. 将您的模型构建和训练为名为 g_1tf.Graph.
  2. 获取每个变量的最终值并将它们存储为 numpy 数组(使用 Session.run()).
  3. 在名为 g_2 的新 tf.Graph 中,创建 tf.constant() 每个变量的张量,使用相应 numpy 数组的值在第 2 步中获取.
  4. 使用tf.import_graph_def() 将节点从 g_1 复制到 g_2,并使用 input_map 参数替换 中的每个变量g_1 与相应的 tf.constant() 张量在第 3 步中创建.您可能还想使用 input_map 指定一个新的输入张量(例如替换输入管道,带有tf.placeholder()).使用 return_elements 参数指定预测输出张量的名称.

  1. Build and train your model as a tf.Graph called g_1.
  2. Fetch the final values of each of the variables and store them as numpy arrays (using Session.run()).
  3. In a new tf.Graph called g_2, create tf.constant() tensors for each of the variables, using the value of the corresponding numpy array fetched in step 2.
  4. Use tf.import_graph_def() to copy nodes from g_1 into g_2, and use the input_map argument to replace each variable in g_1 with the corresponding tf.constant() tensors created in step 3. You may also want to use input_map to specify a new input tensor (e.g. replacing an input pipeline with a tf.placeholder()). Use the return_elements argument to specify the name of the predicted output tensor.

调用 g_2.as_graph_def() 以获取图的协议缓冲区表示.

Call g_2.as_graph_def() to get a protocol buffer representation of the graph.

(注意:生成的图中会有额外的节点用于训练.虽然它不是公共API的一部分,但您可能希望使用内部的graph_util.extract_sub_graph() 函数从图中剥离这些节点.)

(NOTE: The generated graph will have extra nodes in the graph for training. Although it is not part of the public API, you may wish to use the internal graph_util.extract_sub_graph() function to strip these nodes from the graph.)

这篇关于是否有关于如何生成包含训练过的 TensorFlow 图的 protobuf 文件的示例的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆