在 Tensorflow 中使用 tf.while_loop 更新变量 [英] Update a variable with tf.while_loop in Tensorflow

查看:35
本文介绍了在 Tensorflow 中使用 tf.while_loop 更新变量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想在 Tensorflow 中更新一个变量,因此我使用 tf.while_loop 如下:

I want to update a variable in Tensorflow and for that reason I use the tf.while_loop like:

a = tf.Variable([0, 0, 0, 0, 0, 0] , dtype = np.int16)

i = tf.constant(0)
size = tf.size(a)

def condition(i, size, a):
    return tf.less(i, size)

def body(i, size, a):
    a = tf.scatter_update(a, i , i)
    return [tf.add(i, 1), size, a]

r = tf.while_loop(condition, body, [i, size, a])

这是我正在尝试做的一个例子.发生的错误是 AttributeError: 'Tensor' object has no attribute '_lazy_read'.在 Tensorflow 中更新变量的适当方法是什么?

This is an example for what I am trying to do. The error that occurs is AttributeError: 'Tensor' object has no attribute '_lazy_read'. What is the appropriate way to update a variable in Tensorflow?

推荐答案

这在一个代码和执行之前并不明显.就像这样模式

This isn't obvious until one codes and executes. It is like this pattern

import tensorflow as tf


def cond(size, i):
    return tf.less(i,size)

def body(size, i):

    a = tf.get_variable("a",[6],dtype=tf.int32,initializer=tf.constant_initializer(0))
    a = tf.scatter_update(a,i,i)

    tf.get_variable_scope().reuse_variables() # Reuse variables
    with tf.control_dependencies([a]):
        return (size, i+1)

with tf.Session() as sess:

    i = tf.constant(0)
    size = tf.constant(6)
    _,i = tf.while_loop(cond,
                    body,
                    [size, i])

    a = tf.get_variable("a",[6],dtype=tf.int32)

    init = tf.initialize_all_variables()
    sess.run(init)

    print(sess.run([a,i]))

输出是

[数组([0, 1, 2, 3, 4, 5]), 6]

[array([0, 1, 2, 3, 4, 5]), 6]

  1. tf.get_variable使用这些参数获取现有变量或创建一个新的.
  2. tf.control_dependencies这是一个发生之前 关系.在这种情况下,我知道 scatter_update 发生在 while 递增和返回之前.没有这个它就不会更新.
  1. tf.get_variableGets an existing variable with these parameters or create a new one.
  2. tf.control_dependencies It is a happens-before relationship. In this case I understand that the scatter_update happens before the while increments and returns. It doesn't update without this.

注意:我并没有真正理解错误的含义或原因.我也明白了.

Note : I didn't really understand the meaning or cause of the error. I get that too.

这篇关于在 Tensorflow 中使用 tf.while_loop 更新变量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆