Tensorflow:如何替换计算图中的节点? [英] Tensorflow: How to replace a node in a calculation graph?
问题描述
如果你有两个不相交的图,并想把它们联系起来,把这个:
If you have two disjoint graphs, and want to link them, turning this:
x = tf.placeholder('float')
y = f(x)
y = tf.placeholder('float')
z = f(y)
进入这个:
x = tf.placeholder('float')
y = f(x)
z = g(y)
有没有办法做到这一点?在某些情况下,它似乎可以使施工更容易.
Is there a way to do that? It seems like it could make construction easier in some cases.
例如,如果您有一个图形,将输入图像作为 tf.placeholder
,并且想要优化输入图像,深梦风格,有没有办法只替换占位符使用 tf.variable
节点?或者你必须在构建图表之前考虑这一点?
For example if you have a graph that has the input image as a tf.placeholder
, and want to optimize the input image, deep-dream style, is there a way to just replace the placeholder with a tf.variable
node? Or do you have to think of that before building the graph?
推荐答案
TL;DR:如果您可以将这两个计算定义为 Python 函数,那么您应该这样做.如果您不能,TensorFlow 中有更高级的功能来序列化和导入图形,这允许您组合来自不同来源的图形.
TL;DR: If you can define the two computations as Python functions, you should do that. If you can't, there's more advanced functionality in TensorFlow to serialize and import graphs, which allows you to compose graphs from different sources.
在 TensorFlow 中执行此操作的一种方法是将不相交的计算构建为单独的 tf.Graph
对象,然后使用 Graph.as_graph_def()
:
One way to do this in TensorFlow is to build the disjoint computations as separate tf.Graph
objects, then convert them to serialized protocol buffers using Graph.as_graph_def()
:
with tf.Graph().as_default() as g_1:
input = tf.placeholder(tf.float32, name="input")
y = f(input)
# NOTE: using identity to get a known name for the output tensor.
output = tf.identity(y, name="output")
gdef_1 = g_1.as_graph_def()
with tf.Graph().as_default() as g_2: # NOTE: g_2 not g_1
input = tf.placeholder(tf.float32, name="input")
z = g(input)
output = tf.identity(y, name="output")
gdef_2 = g_2.as_graph_def()
然后您可以使用 tf.import_graph_def()
:
Then you could compose gdef_1
and gdef_2
into a third graph, using tf.import_graph_def()
:
with tf.Graph().as_default() as g_combined:
x = tf.placeholder(tf.float32, name="")
# Import gdef_1, which performs f(x).
# "input:0" and "output:0" are the names of tensors in gdef_1.
y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
return_elements=["output:0"])
# Import gdef_2, which performs g(y)
z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
return_elements=["output:0"]
这篇关于Tensorflow:如何替换计算图中的节点?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!