tf.assign 给变量 slice 在 tf.while_loop 中不起作用 [英] tf.assign to variable slice doesn't work inside tf.while_loop

查看:22
本文介绍了tf.assign 给变量 slice 在 tf.while_loop 中不起作用的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

下面的代码有什么问题?tf.assign 操作在应用于 tf.Variable 的切片时工作得很好,如果它发生在循环之外.但是,在这种情况下,它给出了以下错误.

What is wrong with the following code? The tf.assign op works just fine when applied to a slice of a tf.Variable if it happens outside of a loop. But, in this context, it gives the error below.

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i, a):
    return i < n 

def body(i, a):
    tf.assign(a[i], a[i-1] + a[i-2])
    return i + 1, a

i, b = tf.while_loop(cond, body, [2, a]) 

结果:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3210, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2942, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2879, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/hrbigelow/ai/lb-wavenet/while_var_test.py", line 11, in body
    tf.assign(a[i], a[i-1] + a[i-2])
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py", line 220, in assign
    return ref.assign(value, name=name)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 697, in assign
    raise ValueError("Sliced assignment is only supported for variables")
ValueError: Sliced assignment is only supported for variables

推荐答案

您的变量不是循环内运行的操作的输出,它是循环外的外部实体.因此,您不必将其作为参数提供.

Your variable is not an output of the operations run inside your loop, it is an external entity living outside the loop. So you do not have to provide it as an argument.

此外,您需要强制执行更新,例如在 body 中使用 tf.control_dependencies.

Also, you need to enforce the update to take place, for example using tf.control_dependencies in body.

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i):
    return i < n 

def body(i):
    op = tf.assign(a[i], a[i-1] + a[i-2])
    with tf.control_dependencies([op]):
      return i + 1

i = tf.while_loop(cond, body, [2])

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
i.eval()
print(a.eval())
# [ 1  1  2  3  5  8 13 21 34 55 89]

您可能需要谨慎并设置 parallel_iterations=1 以强制循环按顺序运行.

Possibly you may want to be cautious and set parallel_iterations=1 to enforce the loop to run sequentially.

这篇关于tf.assign 给变量 slice 在 tf.while_loop 中不起作用的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆