删除TensorFlow图中的少数几个节点 [英] Deleting all but a few nodes in TensorFlow graph

查看:254
本文介绍了删除TensorFlow图中的少数几个节点的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我的TensorFlow用例要求我为每个需要处理的实例建立一个新的计算图。

My use-case of TensorFlow requires me to build a new computation graph for each instance that needs to be processed. This ends up blowing up the memory requirements.

除了一些作为模型参数的 tf.Variables 外,我想要删除所有其他节点。其他有类似问题的人发现 tf.reset_default_graph()很有用,但是这将摆脱我需要坚持的模型参数。

Apart from a few tf.Variables that are model parameters, I'd like to delete all other nodes. Other people with similar problems have found tf.reset_default_graph() to be useful, but this would get rid of the model parameters that I need to persist.

除了这些节点外,我还能用来删除所有节点吗?

What can I use to delete all but these nodes?

编辑:
实例特定的计算实际上只是意味着我在添加了很多新操作。我相信这些操作是内存问题背后的原因。

The instance specific computations actually just means I am adding a lot new operations. I believe these operations are the reason behind the memory issues.

更新:
请参阅最近发布的tensorflow折叠( https://github.com/tensorflow/fold ),它允许动态构建计算图。

UPDATE: See the recently released tensorflow fold (https://github.com/tensorflow/fold) which allows dynamic construction of computation graphs.

推荐答案

tf.graph数据结构被设计为仅追加数据结构。因此,不可能删除或修改现有节点。通常这不是问题,因为在运行会话时仅处理必要的子图。

The tf.graph data-structure is designed to be an append-only data-structure. It is therefore not possible to remove or modify existing nodes. Usually this is not a problem, as only the necessary subgraph is processed when running a session.

您可以尝试将图形的Variabels复制到新图形中并删除旧图形。要对其进行存档,只需运行:

What you can try is to copy the Variabels of your graph into a new graph and delete the old one. To archive this just run:

old_graph = tf.get_default_graph() # Save the old graph for later iteration
new_graph = tf.graph() # Create an empty graph
new_graph.set_default() # Makes the new graph default

如果要遍历旧图中的所有节点,请使用:

If you want to iterate over all nodes in the old graph use:

for node in old_graph.get_operations():
    if node.type == 'Variable':
       # read value of variable and copy it into new Graph

或者,您可以使用:

for node in old_graph.get_collection('trainable_variables'):
   # iterates over all trainable Variabels
   # read and create new variable

还可以查看 python / framework / ops.py:1759 来了解更多处理图节点的方法。

Have also a look at python/framework/ops.py : 1759 to see more ways on manipulating nodes in graph.

但是,在弄乱 tf.Graph 之前,我强烈建议您考虑一下是否真的需要。通常,人们可以尝试概括计算并使用共享变量来构建图,这样您要处理的每个实例都是该图的子图。

However before you mess around with tf.Graph I would strongly recommend to consider whether this is really required. Usually one can try to generalize the computation and use shared variables build a graph, so that each instance you want to process is a subgraph of this graph.

这篇关于删除TensorFlow图中的少数几个节点的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆