Tensorflow:如何替换计算图中的节点? [英] Tensorflow: How to replace a node in a calculation graph?

查看:49
本文介绍了Tensorflow:如何替换计算图中的节点?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如果你有两个不相交的图,并想把它们联系起来,把这个:

If you have two disjoint graphs, and want to link them, turning this:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)

进入这个:

x = tf.placeholder('float')
y = f(x)
z = g(y)

有没有办法做到这一点?在某些情况下,它似乎可以使施工更容易.

Is there a way to do that? It seems like it could make construction easier in some cases.

例如,如果您有一个图形,将输入图像作为 tf.placeholder,并且想要优化输入图像,深梦风格,有没有办法只替换占位符使用 tf.variable 节点?或者你必须在构建图表之前考虑这一点?

For example if you have a graph that has the input image as a tf.placeholder, and want to optimize the input image, deep-dream style, is there a way to just replace the placeholder with a tf.variable node? Or do you have to think of that before building the graph?

推荐答案

TL;DR:如果您可以将这两个计算定义为 Python 函数,那么您应该这样做.如果您不能,TensorFlow 中有更高级的功能来序列化和导入图形,这允许您组合来自不同来源的图形.

TL;DR: If you can define the two computations as Python functions, you should do that. If you can't, there's more advanced functionality in TensorFlow to serialize and import graphs, which allows you to compose graphs from different sources.

在 TensorFlow 中执行此操作的一种方法是将不相交的计算构建为单独的 tf.Graph 对象,然后使用 Graph.as_graph_def():

One way to do this in TensorFlow is to build the disjoint computations as separate tf.Graph objects, then convert them to serialized protocol buffers using Graph.as_graph_def():

with tf.Graph().as_default() as g_1:
  input = tf.placeholder(tf.float32, name="input")
  y = f(input)
  # NOTE: using identity to get a known name for the output tensor.
  output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1       
  input = tf.placeholder(tf.float32, name="input")
  z = g(input)
  output = tf.identity(y, name="output")

gdef_2 = g_2.as_graph_def()

然后您可以使用 tf.import_graph_def():

Then you could compose gdef_1 and gdef_2 into a third graph, using tf.import_graph_def():

with tf.Graph().as_default() as g_combined:
  x = tf.placeholder(tf.float32, name="")

  # Import gdef_1, which performs f(x).
  # "input:0" and "output:0" are the names of tensors in gdef_1.
  y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                           return_elements=["output:0"])

  # Import gdef_2, which performs g(y)
  z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                           return_elements=["output:0"]

这篇关于Tensorflow:如何替换计算图中的节点?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆