Avoiding duplicating graph in tensorflow (LSTM model)


Problem description


I have the following simplified code (actually, unrolled LSTM model):

import tensorflow as tf  # TF 1.x graph-mode semantics

def func(a, b):
    with tf.variable_scope('name'):
        res = tf.add(a, b)  # a new Add op is created in the graph on every call
    print(res.name)
    return res

func(tf.constant(10), tf.constant(20))


Whenever I run the last line, it seems to change the graph. But I don't want the graph to change. My actual code is a neural network model, but it is too large, so I've added the simplified code above. I want to call func without changing the model's graph, but it changes. I've read about variable scope in TensorFlow, but it seems I haven't understood it at all.
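For illustration, a quick check (a minimal sketch, assuming TF 1.x graph mode) makes the duplication visible: every call to func enters a freshly uniquified scope and appends new ops to the default graph.

import tensorflow as tf

def func(a, b):
    with tf.variable_scope('name'):
        res = tf.add(a, b)
    print(res.name)
    return res

func(tf.constant(10), tf.constant(20))  # prints: name/Add:0
func(tf.constant(10), tf.constant(20))  # prints: name_1/Add:0 (scope uniquified)
# The op count grows with every call, i.e. the graph really is duplicated:
print(len(tf.get_default_graph().get_operations()))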

Recommended answer


You should take a look at the source code of tf.nn.dynamic_rnn, specifically the _dynamic_rnn_loop function in python/ops/rnn.py (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py#L647) - it solves the same problem. In order not to blow up the graph, it uses tf.while_loop to reuse the same graph ops for new data. But this approach adds some restrictions, namely that the shapes of the tensors passing through the loop must be invariant. See the examples in the tf.while_loop documentation:

import tensorflow as tf

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
# m's leading dimension doubles each iteration, so its shape invariant
# must leave that dimension unconstrained (None):
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
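As a rough sketch of the same idea applied to the simplified func above (my own illustration, assuming TF 1.x graph mode, not the asker's actual model): the add op is built exactly once inside the loop body, and tf.while_loop replays it on every iteration instead of stamping out a new copy per call.

import tensorflow as tf

a = tf.constant(10)
b = tf.constant(20)

def cond(i, acc):
    return i < 5  # five "unrolled" steps

def body(i, acc):
    # tf.add is constructed once here; tf.while_loop reuses these ops on
    # every iteration, so the graph size stays constant however many
    # steps the loop runs.
    return i + 1, acc + tf.add(a, b)

i0 = tf.constant(0)
acc0 = tf.constant(0)
_, total = tf.while_loop(cond, body, loop_vars=[i0, acc0])

with tf.Session() as sess:
    print(sess.run(total))  # 150

Since the loop variables here are scalars, no shape_invariants are needed; they only matter when a tensor's shape changes across iterations, as in the documentation example above.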

