How can I list all TensorFlow variables a node depends on?
Question
How can I list all TensorFlow variables/constants/placeholders a node depends on?
Example 1 (addition of constants):
import tensorflow as tf
a = tf.constant(1, name='a')
b = tf.constant(3, name='b')
c = tf.constant(9, name='c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')
sess = tf.Session()
print(sess.run([d, e]))
I would like to have a function list_dependencies() such that:
list_dependencies(d) returns ['a', 'b']
list_dependencies(e) returns ['a', 'b', 'c']
Example 2 (matrix multiplication between a placeholder and a weight matrix, followed by the addition of a bias vector):
import tensorflow as tf
tf.set_random_seed(1)
input_size = 5
output_size = 3
input = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W = tf.get_variable(
"W",
shape=[input_size, output_size],
initializer=tf.contrib.layers.xavier_initializer())
b = tf.get_variable(
"b",
shape=[output_size],
initializer=tf.constant_initializer(2))
output = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]}))
I would like to have a function list_dependencies() such that:
list_dependencies(output) returns ['W', 'input']
list_dependencies(output_bias) returns ['W', 'b', 'input']
Recommended answer
Here are utilities I use for this (from https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py)
import tensorflow as tf

# computation flows from parents to children
def parents(op):
    """Ops whose output tensors are inputs to `op`."""
    return set(t.op for t in op.inputs)

def children(op):
    """Ops that consume any output tensor of `op`."""
    return set(consumer for out in op.outputs for consumer in out.consumers())

def get_graph():
    """Creates dictionary {node: {child1, child2, ..},..} for current
    TensorFlow graph. Result is compatible with networkx/toposort"""
    ops = tf.get_default_graph().get_operations()
    return {op: children(op) for op in ops}

def print_tf_graph(graph):
    """Prints tensorflow graph in dictionary form."""
    for node in graph:
        for child in graph[node]:
            print("%s -> %s" % (node.name, child.name))
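The get_graph() docstring says its {node: children} dictionary is compatible with networkx/toposort. As a dependency-free sketch (a swapped-in technique, not part of the original answer), Python's standard-library graphlib (Python 3.9+) can turn the same shape of dictionary into an execution order. graphlib.TopologicalSorter expects each node mapped to its predecessors, so the children map is inverted first. Plain strings stand in for tf.Operation objects, and computation_order is a hypothetical helper name.

```python
from graphlib import TopologicalSorter


def computation_order(graph):
    """Order nodes so that every parent comes before its children.

    `graph` maps node -> set of children (the shape get_graph() returns).
    TopologicalSorter wants node -> set of predecessors, so invert it.
    """
    predecessors = {node: set() for node in graph}
    for node, kids in graph.items():
        for kid in kids:
            predecessors.setdefault(kid, set()).add(node)
    return list(TopologicalSorter(predecessors).static_order())


# Strings stand in for ops; this mirrors Example 1 (d = a + b, e = d + c).
example_graph = {"a": {"d"}, "b": {"d"}, "c": {"e"}, "d": {"e"}, "e": set()}
order = computation_order(example_graph)
print(order)
```

With a real graph you would call computation_order(get_graph()); tf.Operation objects are hashable, so they can serve as dictionary keys in the same way.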
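These functions work on ops. To get the op that produces tensor t, use t.op. To get the tensors produced by an op, use op.outputs. Note that get_graph() and print_tf_graph() follow the parent-to-child direction; to find everything a node depends on, start from t.op and repeatedly follow parents() instead.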
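Combining t.op with the parents relation (op.inputs) gives the list_dependencies() function the question asks for: walk upstream from t.op and keep the ops whose type marks them as a constant, placeholder, or variable. This sketch assumes TF 1.x graph-mode op type strings (Const, Placeholder, VariableV2 for ref variables, VarHandleOp for resource variables); check op.type in your own graph if they differ. So that it runs without TensorFlow, FakeOp/FakeTensor are hypothetical stand-ins exposing only the attributes the function touches (.op, .inputs, .type, .name). The "W/read" Identity op and the "output_bias/MatMul" name are also assumptions about what TF 1.x creates. With a real graph you would pass d, e, output, or output_bias directly; that path is untested here.

```python
# Op types treated as leaf dependencies. Assumption: these are the TF 1.x
# graph-mode type strings for constants, placeholders, ref variables
# ("VariableV2") and resource variables ("VarHandleOp").
LEAF_OP_TYPES = frozenset({"Const", "Placeholder", "VariableV2", "VarHandleOp"})


def list_dependencies(tensor, leaf_types=LEAF_OP_TYPES):
    """Return the sorted names of leaf ops that `tensor` depends on.

    Walks upstream from tensor.op through op.inputs (the `parents`
    relation) using an explicit stack, so deep graphs cannot hit
    Python's recursion limit. Shared subgraphs are visited once.
    """
    seen = set()
    leaves = set()
    stack = [tensor.op]
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        if op.type in leaf_types:
            leaves.add(op.name)
        stack.extend(t.op for t in op.inputs)
    return sorted(leaves)


class FakeTensor:
    """Hypothetical stand-in for tf.Tensor: exposes only .op."""

    def __init__(self, op):
        self.op = op


class FakeOp:
    """Hypothetical stand-in for tf.Operation: .name, .type, .inputs, .outputs."""

    def __init__(self, name, op_type, inputs=()):
        self.name = name
        self.type = op_type
        self.inputs = list(inputs)
        self.outputs = [FakeTensor(self)]


def fake(name, op_type, *inputs):
    """Build a FakeOp and return its single output tensor."""
    return FakeOp(name, op_type, inputs).outputs[0]


# Example 1: d = a + b, e = d + c
a, b, c = (fake(n, "Const") for n in "abc")
d = fake("d", "Add", a, b)
e = fake("e", "Add", d, c)
print(list_dependencies(d))  # ['a', 'b']
print(list_dependencies(e))  # ['a', 'b', 'c']

# Example 2 with ref variables: MatMul reads W through an Identity "W/read".
inp = fake("input", "Placeholder")
W = fake("W", "VariableV2")
bias = fake("b", "VariableV2")
w_read = fake("W/read", "Identity", W)
b_read = fake("b/read", "Identity", bias)
output = fake("output", "MatMul", inp, w_read)
mm = fake("output_bias/MatMul", "MatMul", inp, w_read)
output_bias = fake("output_bias", "BiasAdd", mm, b_read)
print(list_dependencies(output))       # ['W', 'input']
print(list_dependencies(output_bias))  # ['W', 'b', 'input']
```

Filtering by op type, rather than returning every ancestor, is what keeps intermediate ops such as "d" or "W/read" out of the result; widen or narrow leaf_types to change what counts as a dependency.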