删除TensorFlow图中的少数几个节点 [英] Deleting all but a few nodes in TensorFlow graph
问题描述
我的TensorFlow用例要求我为每个需要处理的实例建立一个新的计算图。
My use-case of TensorFlow requires me to build a new computation graph for each instance that needs to be processed. This ends up blowing up the memory requirements.
除了一些作为模型参数的 tf.Variables
外,我想要删除所有其他节点。其他有类似问题的人发现 tf.reset_default_graph()
很有用,但是这将摆脱我需要坚持的模型参数。
Apart from a few tf.Variables
that are model parameters, I'd like to delete all other nodes. Other people with similar problems have found tf.reset_default_graph()
to be useful, but this would get rid of the model parameters that I need to persist.
除了这些节点外,我还能用来删除所有节点吗?
What can I use to delete all but these nodes?
编辑:
实例特定的计算实际上只是意味着我在添加了很多新操作。我相信这些操作是内存问题背后的原因。
The instance specific computations actually just means I am adding a lot new operations. I believe these operations are the reason behind the memory issues.
更新:
请参阅最近发布的tensorflow折叠( https://github.com/tensorflow/fold ),它允许动态构建计算图。
UPDATE: See the recently released tensorflow fold (https://github.com/tensorflow/fold) which allows dynamic construction of computation graphs.
推荐答案
tf.graph数据结构被设计为仅追加数据结构。因此,不可能删除或修改现有节点。通常这不是问题,因为在运行会话时仅处理必要的子图。
The tf.graph data-structure is designed to be an append-only data-structure. It is therefore not possible to remove or modify existing nodes. Usually this is not a problem, as only the necessary subgraph is processed when running a session.
您可以尝试将图形的Variabels复制到新图形中并删除旧图形。要对其进行存档,只需运行:
What you can try is to copy the Variabels of your graph into a new graph and delete the old one. To archive this just run:
old_graph = tf.get_default_graph() # Save the old graph for later iteration
new_graph = tf.graph() # Create an empty graph
new_graph.set_default() # Makes the new graph default
如果要遍历旧图中的所有节点,请使用:
If you want to iterate over all nodes in the old graph use:
for node in old_graph.get_operations():
if node.type == 'Variable':
# read value of variable and copy it into new Graph
或者,您可以使用:
for node in old_graph.get_collection('trainable_variables'):
# iterates over all trainable Variabels
# read and create new variable
还可以查看 python / framework / ops.py:1759
来了解更多处理图节点的方法。
Have also a look at python/framework/ops.py : 1759
to see more ways on manipulating nodes in graph.
但是,在弄乱 tf.Graph
之前,我强烈建议您考虑一下是否真的需要。通常,人们可以尝试概括计算并使用共享变量来构建图,这样您要处理的每个实例都是该图的子图。
However before you mess around with tf.Graph
I would strongly recommend to consider whether this is really required. Usually one can try to generalize the computation and use shared variables build a graph, so that each instance you want to process is a subgraph of this graph.
这篇关于删除TensorFlow图中的少数几个节点的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!