查找 tensorflow op 依赖的所有变量 [英] Find all variables that a tensorflow op depends upon
问题描述
有没有办法找到给定操作(通常是损失)所依赖的所有变量?我想使用它然后将此集合传递到 optimizer.minimize()
或 tf.gradients()
使用各种 set().intersection()
组合.
到目前为止,我已经找到了 op.op.inputs
并尝试了一个简单的 BFS,但我从来没有遇到 Variable
对象返回的 tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
或 slim.get_variables()
相应的Tensor.op._id和
Variables.op._id"字段之间似乎存在对应关系,但我不确定这是我应该依赖的东西.>
或者,也许我一开始就不应该这样做?我当然可以在构建图形时精心构建不相交的变量集,但是如果我更改模型,很容易遗漏一些东西.
tf.Variable.op
的文档不是特别清楚,但它确实引用了 tf.Variable
的实现:任何依赖于 tf.Variable
的操作都将开启该操作的路径.由于 tf.Operation
对象是可散列的,您可以将其用作 dict
的键,将 tf.Operation
对象映射到相应的 tf.Variable
对象,然后像之前一样执行 BFS:
op_to_var = {var.op: var for var in tf.trainable_variables()}起始操作 = ...从属变量 = []队列 = collections.deque()queue.append(starting_op)访问 = 设置([starting_op])排队时:op = queue.popleft()尝试:dependent_vars.append(op_to_var[op])除了 KeyError:# `op` 不是一个变量,所以搜索它的输入(如果有的话).对于 op.inputs 中的 op_input:如果 op_input.op 未访问:queue.append(op_input.op)访问.添加(op_input.op)
Is there a way to find all variables that a given operation (usually a loss) depends upon?
I would like to use this to then pass this collection into optimizer.minimize()
or tf.gradients()
using various set().intersection()
combinations.
So far I have found op.op.inputs
and tried a simple BFS on that, but I never chance upon Variable
objects as returned by tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
or slim.get_variables()
There does seem to be a correspondence between corresponding 'Tensor.op._idand
Variables.op._id` fields, but I'm not sure that's a something I should rely upon.
Or maybe I should't want to do this in the first place? I could of course construct my disjoint sets of variables meticulously while building my graph, but then it would be easy to miss something if I change the model.
The documentation for tf.Variable.op
is not particularly clear, but it does refer to the crucial tf.Operation
used in the implementation of a tf.Variable
: any op that depends on a tf.Variable
will be on a path from that operation. Since the tf.Operation
object is hashable, you can use it as the key of a dict
that maps tf.Operation
objects to the corresponding tf.Variable
object, and then perform the BFS as before:
op_to_var = {var.op: var for var in tf.trainable_variables()}
starting_op = ...
dependent_vars = []
queue = collections.deque()
queue.append(starting_op)
visited = set([starting_op])
while queue:
op = queue.popleft()
try:
dependent_vars.append(op_to_var[op])
except KeyError:
# `op` is not a variable, so search its inputs (if any).
for op_input in op.inputs:
if op_input.op not in visited:
queue.append(op_input.op)
visited.add(op_input.op)
这篇关于查找 tensorflow op 依赖的所有变量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!