Building Tensorflow Graphs Inside of Functions

Problem Description

I'm learning Tensorflow and am trying to properly structure my code. I (more or less) know how to build graphs either bare or as class methods, but I'm trying to figure out how best to structure the code. I've tried the simple example:

import tensorflow as tf

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

which should just print out 4. However, when I do that, I get the error:

Cannot interpret feed_dict key as Tensor: Tensor 
Tensor("Placeholder:0", dtype=int8) is not an element of this graph.

I'm pretty sure this is because the placeholder inside the function build_graph is private, but shouldn't the with tf.Session(graph=graph) take care of that? Is there a better way of using a feed dict in a situation like this?

Recommended Answer

There are a few options.

Option 1: just pass the name of the tensor instead of the tensor itself.

with tf.Session(graph=graph) as sess:
    feed = {"Placeholder:0": 3}      
    print(sess.run("Add:0", feed_dict=feed))

In this case, it's probably best to give the nodes meaningful names, instead of using the default names as above:

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {"a:0": 3}
    print(sess.run("b:0", feed_dict=feed))

Recall that the outputs of an operation named "foo" are tensors named "foo:0", "foo:1", and so on. Most operations have just one output.
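
As a concrete illustration of that naming convention (a standalone sketch, with the op name "foo" and the shape invented purely for the example), tf.split builds a single op whose two outputs pick up the ":0" and ":1" suffixes:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.int8, shape=[4], name="x")
    # A single op named "foo" that produces two output tensors.
    first, second = tf.split(x, 2, name="foo")

print(first.name)   # "foo:0"
print(second.name)  # "foo:1"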

Option 2: make your build_graph() function return all the important nodes.

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    return g, a, b

graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

Option 3: add important nodes to a collection

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    for node in (a, b):
        g.add_to_collection("important_stuff", node)
    return g

graph = build_graph()
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

Option 4: as suggested by @pohe you can use get_tensor_by_name()

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    return g

graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))
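
A small side note on name-based lookups, using the same graph as above: get_tensor_by_name() returns one output tensor of an op, while get_operation_by_name() returns the operation object itself, which can be handy when you want the op rather than a value (for example, as a control dependency):

# The operation named "b", as opposed to its output tensor "b:0".
add_op = graph.get_operation_by_name("b")
print(add_op.outputs[0].name)  # "b:0"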

I personally use option 2 most often, it's pretty straightforward and doesn't require playing with names. I use option 3 when the graph is large and will live for a long time, because the collection gets saved along with the model, and it's a quick way to document what really matters. I don't really use option 1, because I prefer to have actual references to objects (not sure why). Option 4 is useful when you are working with a graph built by someone else, and they didn't give you direct references to tensors.
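
To make the point about collections being saved concrete, here is a minimal sketch (the filename "model.meta" is made up, and tf.train.export_meta_graph is used directly because this toy graph has no variables for a tf.train.Saver to write; a real model saved with a Saver produces the same kind of .meta file). The collection travels with the exported graph, so a later script can recover the tensors without holding any Python references:

import tensorflow as tf

# Export the Option 3 graph, including its "important_stuff" collection.
graph = build_graph()
tf.train.export_meta_graph(filename="model.meta", graph=graph)

# Later, possibly in another script: re-import the graph and recover the tensors.
restored = tf.Graph()
with restored.as_default():
    tf.train.import_meta_graph("model.meta")
    a, b = restored.get_collection("important_stuff")
    with tf.Session(graph=restored) as sess:
        print(sess.run(b, feed_dict={a: 3}))  # prints 4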

Hope this helps!
