Confused by the behavior of `tf.cond`


Question

I need a conditional control flow in my graph. If `pred` is True, the graph should call an op that updates a variable and then returns it; otherwise it returns the variable unchanged. A simplified version is:

import tensorflow as tf

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

However, I find that both `pred=True` and `pred=False` lead to the same result `y=[2]`, which means the assign op is also called even when `update_x_2` is not selected by `tf.cond`. How can this be explained, and how can it be fixed?

Answer

TL;DR: If you want `tf.cond()` to perform a side effect (like an assignment) in one of the branches, you must create the op that performs the side effect inside the function that you pass to `tf.cond()`.

The behavior of `tf.cond()` is a little unintuitive. Because execution in a TensorFlow graph flows forward through the graph, all operations that you refer to in either branch must execute before the conditional is evaluated. This means that both the true and the false branches receive a control dependency on the `tf.assign()` op, and so `y` always gets set to `2`, even when `pred` is `False`.

The solution is to create the `tf.assign()` op inside the function that defines the true branch. For example, you could structure your code as follows:

import tensorflow as tf

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]

