来自训练元图的权重和偏差 [英] Weights and Bias from Trained Meta Graph

查看:53
本文介绍了来自训练元图的权重和偏差的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已成功将重新训练的 InceptionV3 NN 导出为 TensorFlow 元图.我已经成功地将这个 protobuf 读回了 python,但我正在努力寻找一种导出每一层权重和偏差值的方法,我假设这些值存储在元图 protobuf 中,用于在 TensorFlow 之外重新创建 nn.

I have successfully exported a re-trained InceptionV3 NN as a TensorFlow meta graph. I have read this protobuf back into python successfully, but I am struggling to see a way to export each layers weight and bias values, which I am assuming is stored within the meta graph protobuf, for recreating the nn outside of TensorFlow.

我的工作流程是这样的:

My workflow is as such:

Retrain final layer for new categories
Export meta graph tf.train.export_meta_graph(filename='model.meta')
Build python pb2.py using Protoc and meta_graph.proto
Load Protobuf:

import meta_graph_pb2
saved = meta_graph_pb2.CollectionDef()
with open('model.meta', 'rb') as f:
  saved.ParseFromString(f.read())

从这里我可以查看图表的大多数方面,例如节点名称等,但我认为我的经验不足使得很难找到访问每个相关层的权重和偏差值的正确方法.

From here I can view most aspects of the graph, like node names and such, but I think my inexperience is making it difficult to track down the correct way to access the weight and bias values for each relevant layer.

推荐答案

MetaGraphDef proto 实际上并不包含权重和偏差的值.相反,它提供了一种将 GraphDef 与存储在一个或多个检查点文件中的权重相关联的方法,该文件由 tf.train.Saver.MetaGraphDef 教程有更多细节,但大致结构如下:

The MetaGraphDef proto doesn't actually contain the values of the weights and biases. Instead it provides a way to associate a GraphDef with the weights stored in one or more checkpoint files, written by a tf.train.Saver. The MetaGraphDef tutorial has more details, but the approximate structure is as follows:

  1. 在您的训练计划中,使用 tf.train.Saver 写出一个检查点.这也会将 MetaGraphDef 写入同一目录中的 .meta 文件.

  1. In you training program, write out a checkpoint using a tf.train.Saver. This will also write a MetaGraphDef to a .meta file in the same directory.

saver = tf.train.Saver(...)
# ...
saver.save(sess, "model")

您应该在检查点目录中找到名为 model.metamodel-NNNN(对于某些整数 NNNN)的文件.

You should find files called model.meta and model-NNNN (for some integer NNNN) in your checkpoint directory.

在另一个程序中,您可以导入您刚刚创建的 MetaGraphDef,并从检查点恢复.

In another program, you can import the MetaGraphDef you just created, and restore from a checkpoint.

saver = tf.train.import_meta_graph("model.meta")
saver.restore("model-NNNN")  # Or whatever checkpoint filename was written.

如果你想获取每个变量的值,你可以(例如)在tf.all_variables()集合中找到该变量并将其传递给sess.run() 获取其值.例如,要打印所有变量的值,您可以执行以下操作:

If you want to get the value of each variable, you can (for example) find the variable in tf.all_variables() collection and pass it to sess.run() to get its value. For example, to print the values of all variables, you can do the following:

for var in tf.all_variables():
  print var.name, sess.run(var)

您还可以过滤 tf.all_variables() 以找到您试图从模型中提取的特定权重和偏差.

You could also filter tf.all_variables() to find the particular weights and biases that you're trying to extract from the model.

这篇关于来自训练元图的权重和偏差的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆