TensorFlow stuck in an endless loop using tf.while_loop()


Problem description

I am using TensorFlow to implement a network that needs to use tf.while_loop():

import tensorflow as tf
import numpy as np
class model(object):
    def __init__(self):
        self.argmax_ep_gate_array = [ tf.placeholder(tf.int32, [None]) for _ in range(10)]
        argmax_ep_gate_array_concat = tf.concat(0, self.argmax_ep_gate_array)
        story_len = tf.constant(7)
        starter = tf.constant(0)
        z = []   # ordinary Python list used to collect values produced inside the loop body
        def body(hops):
            hops = tf.add(hops,1)
            z.append(hops)   # side effect: appends while the graph is built, not once per iteration
            return hops
        def condition(hops):
            return tf.logical_and(
                tf.less(tf.gather(argmax_ep_gate_array_concat, hops), story_len),
                tf.less(hops, tf.constant(20)))

        self.gate_index = tf.while_loop(condition,body,[starter])
        self.z=tf.concat(0,z)

    def step(self, sess):
        feed={}
        for i in range(10):
            feed[self.argmax_ep_gate_array[i].name]=[i]
        print (sess.run([self.gate_index,self.z],feed))
with tf.Session() as sess:
    while_loop = model()
    sess.run(tf.initialize_all_variables())
    while_loop.step(sess)

What have you tried?

I find that if I sess.run() any tensor created in body() that is not returned from the loop, TensorFlow gets stuck in an endless loop. The example above is trivial, but it reveals the problem. In my real case, I use tf.while_loop() to run an RNN that includes something like y = wx + b, where w and b are not returned after the while loop. The forward pass works fine, but when I run back-propagation the program gets stuck in an endless loop. I believe the code above reproduces my issue, because back-propagation does need to modify w and b. Is there any way to handle this?

Recommended answer

TL;DR: You cannot store tensors that were created in the loop body for later use, because that breaks some assumptions about how the loop is structured.

In general, the condition() and body() functions must not have side effects. Indeed, it is unlikely that your program has the intended behavior: TensorFlow executes the body() function only once, to build the necessary graph structure, so z will contain just one element after model.__init__() runs.
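
To see this concretely, here is a minimal standalone sketch (separate from the question's model class; the names are illustrative) showing that TensorFlow calls body() a single time while it builds the loop graph, so a Python list mutated inside the body never grows beyond one element:

import tensorflow as tf

z = []                           # ordinary Python list, mutated as a side effect

def body(i):
    z.append(i)                  # runs while the graph is being BUILT, not executed
    return tf.add(i, 1)

def condition(i):
    return tf.less(i, tf.constant(5))

loop = tf.while_loop(condition, body, [tf.constant(0)])
print(len(z))                    # prints 1: body() was traced exactly once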

Instead, you must build z incrementally in the loop body, using tf.concat() and passing the value along as a loop variable:

starter = tf.constant(0)
z_initial = tf.constant([], dtype=tf.int32)   # empty int32 tensor that the loop will grow

def body(hops, z_prev):
    hops = tf.add(hops, 1)
    # Append the new hop count to the accumulator tensor instead of mutating a Python list.
    z_next = tf.concat(0, [z_prev, tf.expand_dims(hops, 0)])
    return hops, z_next

def condition(hops, z):
    # z is unused here, but condition() must accept the same loop variables that body() returns.
    return tf.logical_and(
        tf.less(tf.gather(argmax_ep_gate_array_concat, hops), story_len),
        tf.less(hops, tf.constant(20)))

self.gate_index, self.z = tf.while_loop(condition, body, [starter, z_initial])
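
One caveat worth adding, as an assumption that goes beyond the original answer: on TensorFlow 1.0 and later, tf.while_loop() checks that every loop variable keeps the same shape from one iteration to the next, so a tensor like z that grows each iteration needs an explicit shape_invariants entry. A minimal sketch of that call:

# Sketch for TensorFlow 1.x+ only: hops keeps its scalar shape, while the length
# of z is declared unknown (None) so it may change between iterations.
self.gate_index, self.z = tf.while_loop(
    condition, body, [starter, z_initial],
    shape_invariants=[starter.get_shape(), tf.TensorShape([None])])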
