查找 tensorflow op 依赖的所有变量 [英] Find all variables that a tensorflow op depends upon

查看:34
本文介绍了查找 tensorflow op 依赖的所有变量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

有没有办法找到给定操作(通常是损失)所依赖的所有变量?我想使用它然后将此集合传递到 optimizer.minimize()tf.gradients() 使用各种 set().intersection() 组合.

到目前为止,我已经找到了 op.op.inputs 并尝试了一个简单的 BFS,但我从来没有遇到 Variable 对象返回的 tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)slim.get_variables()

相应的Tensor.op._idVariables.op._id"字段之间似乎存在对应关系,但我不确定这是我应该依赖的东西.

或者,也许我一开始就不应该这样做?我当然可以在构建图形时精心构建不相交的变量集,但是如果我更改模型,很容易遗漏一些东西.

解决方案

tf.Variable.op 的文档不是特别清楚,但它确实引用了 tf.Variable 的实现:任何依赖于 tf.Variable 的操作都将开启该操作的路径.由于 tf.Operation 对象是可散列的,您可以将其用作 dict 的键,将 tf.Operation 对象映射到相应的 tf.Variable 对象,然后像之前一样执行 BFS:

op_to_var = {var.op: var for var in tf.trainable_variables()}起始操作 = ...从属变量 = []队列 = collections.deque()queue.append(starting_op)访问 = 设置([starting_op])排队时:op = queue.popleft()尝试:dependent_vars.append(op_to_var[op])除了 KeyError:# `op` 不是一个变量,所以搜索它的输入(如果有的话).对于 op.inputs 中的 op_input:如果 op_input.op 未访问:queue.append(op_input.op)访问.添加(op_input.op)

Is there a way to find all variables that a given operation (usually a loss) depends upon? I would like to use this to then pass this collection into optimizer.minimize() or tf.gradients() using various set().intersection() combinations.

So far I have found op.op.inputs and tried a simple BFS on that, but I never chance upon Variable objects as returned by tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) or slim.get_variables()

There does seem to be a correspondence between corresponding 'Tensor.op._idandVariables.op._id` fields, but I'm not sure that's a something I should rely upon.

Or maybe I should't want to do this in the first place? I could of course construct my disjoint sets of variables meticulously while building my graph, but then it would be easy to miss something if I change the model.

解决方案

The documentation for tf.Variable.op is not particularly clear, but it does refer to the crucial tf.Operation used in the implementation of a tf.Variable: any op that depends on a tf.Variable will be on a path from that operation. Since the tf.Operation object is hashable, you can use it as the key of a dict that maps tf.Operation objects to the corresponding tf.Variable object, and then perform the BFS as before:

op_to_var = {var.op: var for var in tf.trainable_variables()}

starting_op = ...
dependent_vars = []

queue = collections.deque()
queue.append(starting_op)

visited = set([starting_op])

while queue:
  op = queue.popleft()
  try:
    dependent_vars.append(op_to_var[op])
  except KeyError:
    # `op` is not a variable, so search its inputs (if any). 
    for op_input in op.inputs:
      if op_input.op not in visited:
        queue.append(op_input.op)
        visited.add(op_input.op)

这篇关于查找 tensorflow op 依赖的所有变量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆