如何列出节点依赖的所有 Tensorflow 变量? [英] How can I list all Tensorflow variables a node depends on?

查看:24
本文介绍了如何列出节点依赖的所有 Tensorflow 变量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如何列出节点依赖的所有 Tensorflow 变量/常量/占位符?

How can I list all Tensorflow variables/constants/placeholders a node depends on?

示例 1(添加常量):

Example 1 (addition of constants):

import tensorflow as tf

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))

我想要一个函数 list_dependencies() 如:

I would like to have a function list_dependencies() such as:

  • list_dependencies(d) 返回 ['a', 'b']
  • list_dependencies(e) 返回 ['a', 'b', 'c']
  • list_dependencies(d) returns ['a', 'b']
  • list_dependencies(e) returns ['a', 'b', 'c']

示例 2(占位符和权重矩阵之间的矩阵乘法,然后添加偏置向量):

Example 2 (matrix multiplication between a placeholder and a weight matrix, followed by the addition of a bias vector):

tf.set_random_seed(1)
input_size  = 5
output_size = 3
input       = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W           = tf.get_variable(
                "W",
                shape=[input_size, output_size],
                initializer=tf.contrib.layers.xavier_initializer())
b           = tf.get_variable(
                "b",
                shape=[output_size],
                initializer=tf.constant_initializer(2))
output      = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]}))

我想要一个函数 list_dependencies() 如:

I would like to have a function list_dependencies() such as:

  • list_dependencies(output) 返回 ['W', 'input']
  • list_dependencies(output_bias) 返回 ['W', 'b', 'input']
  • list_dependencies(output) returns ['W', 'input']
  • list_dependencies(output_bias) returns ['W', 'b', 'input']

推荐答案

以下是我为此使用的实用程序(来自 https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py)

Here are utilities I use for this (from https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py)

# computation flows from parents to children

def parents(op):
  return set(input.op for input in op.inputs)

def children(op):
  return set(op for out in op.outputs for op in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""

  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}


def print_tf_graph(graph):
  """Prints tensorflow graph in dictionary form."""
  for node in graph:
    for child in graph[node]:
      print("%s -> %s" % (node.name, child.name))

这些函数适用于 ops.要获得生成张量 t 的操作,请使用 t.op.要获得 op op 生成的张量,请使用 op.outputs

These functions work on ops. To get an op that produces tensor t, use t.op. To get tensors produced by op op, use op.outputs

这篇关于如何列出节点依赖的所有 Tensorflow 变量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆