如何在 TensorFlow 中访问 protos 中的值? [英] How to access values in protos in TensorFlow?

查看:23
本文介绍了如何在 TensorFlow 中访问 protos 中的值?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我从 教程 中看到我们可以做到这一点:

<代码>对于 tf.get_default_graph().as_graph_def().node 中的节点:打印节点

在任意网络上完成后,我们会得到许多键值对.例如:

name: "conv2d_2/convolution"操作:Conv2D"输入:max_pooling2d/MaxPool"输入:conv2d_1/内核/读取"设备:/设备:GPU:0"属性{键:T"价值 {类型:DT_FLOAT}}属性{键:数据格式"价值 {s:NHWC"}}属性{键:填充"价值 {s:相同"}}属性{关键:大步"价值 {列表 {我:1我:1我:1我:1}}}属性{键:use_cudnn_on_gpu"价值 {乙:真的}}

如何访问所有这些值并将它们放入 Python 列表中?具体来说,我们如何获取strides"属性并将其中的所有1转换为[1, 1, 1, 1]?

解决方案

TLDR: 下面的代码是您可能想要使用的:

for n in tf.get_default_graph().as_graph_def().node:如果在 n.attr.keys() 中大步前进":打印 n.name, [int(a) for a in n.attr['strides'].list.i]如果 n.attr.keys() 中的形状":打印 n.name, [int(a.size) for a in n.attr['shape'].shape.dim]

这样做的诀窍是了解什么是protobufs.让我们通过 上面提到的教程.

首先声明:

<块引用>

 用于 graph_def.node 中的节点

<块引用>

每个节点都是一个 NodeDef 对象,定义在张量流/核心/框架/node_def.proto.这些是根本TensorFlow 图的构建块,每个图定义一个操作及其输入连接.这里有一个成员NodeDef,以及它们的含义.

注意 node_def.proto 中的以下内容:

  • 它导入 attr_value.proto.
  • 有name、op、input、device、attr等属性.具体来说,输入前面有一个 repeated 术语.我们暂时可以忽略这一点.

这与 Python 类完全一样,因此我们可以调用 node.name、node.op、node.input、node.device、node.attr 等.

我们现在想要访问的是 node.attr 中的内容.如果我们再次参考教程,它会指定:

<块引用>

这是一个保存节点所有属性的键/值存储.这些是节点的永久属性,不会改变的东西运行时,例如卷积过滤器的大小,或不断的操作.因为可以有很多不同的类型属性值,从字符串到整数,再到张量值数组,有一个单独的 protobuf 文件定义了数据结构将它们保存在 tensorflow/core/framework/attr_value.proto 中.

每个属性都有一个唯一的名称字符串,以及期望的属性在定义操作时列出.如果一个属性不是存在于节点中,但在操作中列出了默认值定义,在创建图形时使用该默认值.

您可以通过调用 node.name、node.op、等等.GraphDef 中存储的节点列表是一个完整的模型架构的定义.

由于这是一个键值存储,我们可以调用 n.attr.keys() 来查看该属性具有的键列表.如果有这样的键,我们可以进一步调用 n.attr['strides'] 来访问步幅.当我们尝试打印它时,我们得到以下信息:

list {我:1我:2我:2我:1}

这就是它开始变得混乱的地方,因为我们可能会尝试做 list(n.attr['strides']) 或类似的东西.如果我们查看 attr_value.proto,我们就可以理解发生了什么.我们看到它是oneof value,在本例中它是一个ListValue list,所以我们可以调用n.attr['strides'].list.如果我们打印这个,我们会得到以下内容:

i: 1我:1我:1我:1

我们接下来可能会尝试这样做:[a for a in n.attr['strides'].list][ai for a in n.attr['strides'].list].但是,没有任何效果.这是repeated 是一个需要理解的重要术语的地方.它基本上意味着有一个 int64 列表,您必须使用 i 属性访问它.执行 [int(a) for a in n.attr['strides'].list.i] 然后给了我们我们想要的,一个我们可以使用的 Python 列表:

[1, 1, 1, 1]

I see from the tutorial that we can do this:

for node in tf.get_default_graph().as_graph_def().node: print node

When done on an arbitrary network, we get many key value pairs. For example:

name: "conv2d_2/convolution"
op: "Conv2D"
input: "max_pooling2d/MaxPool"
input: "conv2d_1/kernel/read"
device: "/device:GPU:0"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "data_format"
  value {
    s: "NHWC"
  }
}
attr {
  key: "padding"
  value {
    s: "SAME"
  }
}
attr {
  key: "strides"
  value {
    list {
      i: 1
      i: 1
      i: 1
      i: 1
    }
  }
}
attr {
  key: "use_cudnn_on_gpu"
  value {
    b: true
  }
}

How do I access all these values and put them in Python lists? Specifically, how can we get the "strides" attribute and convert all the 1s there into [1, 1, 1, 1]?

解决方案

TLDR: The code below is what you might want to use:

for n in tf.get_default_graph().as_graph_def().node:
    if 'strides' in n.attr.keys():
        print n.name, [int(a) for a in n.attr['strides'].list.i]
    if 'shape' in n.attr.keys():
        print n.name, [int(a.size) for a in n.attr['shape'].shape.dim]

The trick to doing this is to understand what protobufs are. Let's go through the tutorial mentioned above.

First of all, there's a statement:

for node in graph_def.node

Each node is a NodeDef object, defined in tensorflow/core/framework/node_def.proto. These are the fundamental building blocks of TensorFlow graphs, with each one defining a single operation along with its input connections. Here are the members of a NodeDef, and what they mean.

Note the following in node_def.proto:

  • It imports attr_value.proto.
  • There are attributes such as name, op, input, device, attr. Specifically, there's a repeated term in front of input. We can ignore this for now.

This works exactly like a Python class and we can thus call node.name, node.op, node.input, node.device, node.attr, etc.

What we would like to access now would be the contents in node.attr. If we refer to the tutorial once again, it specifies:

This is a key/value store holding all the attributes of a node. These are the permanent properties of nodes, things that don't change at runtime such as the size of filters for convolutions, or the values of constant ops. Because there can be so many different types of attribute values, from strings, to ints, to arrays of tensor values, there's a separate protobuf file defining the data structure that holds them, in tensorflow/core/framework/attr_value.proto.

Each attribute has a unique name string, and the expected attributes are listed when the operation is defined. If an attribute isn't present in a node, but it has a default listed in the operation definition, that default is used when the graph is created.

You can access all of these members by calling node.name, node.op, etc. in Python. The list of nodes stored in the GraphDef is a full definition of the model architecture.

Since this is a key-value store we can call n.attr.keys() to see a list of keys this attribute has. We can go further to call perhaps n.attr['strides'] to access the strides, if such a key is available. When we try to print this, we get the following:

list {
  i: 1
  i: 2
  i: 2
  i: 1
}

And this is where it starts to get confusing because we might try to do list(n.attr['strides']) or something of this sort. If we look at attr_value.proto, we can understand what's going on. We see that it's oneof value and in this case it's a ListValue list, so we can call n.attr['strides'].list. And if we print this, we get the following:

i: 1
i: 1
i: 1
i: 1

We might next try to do this: [a for a in n.attr['strides'].list] or [a.i for a in n.attr['strides'].list]. However, nothing works. This is where repeated is an important term to understand. It basically means that there's an int64 list and you have to access it with the i attribute. Doing [int(a) for a in n.attr['strides'].list.i] then gives us what we want, a Python list that we can use:

[1, 1, 1, 1]

这篇关于如何在 TensorFlow 中访问 protos 中的值?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆