如何在TensorFlow图中添加if条件? [英] How to add if condition in a TensorFlow graph?

查看:4731
本文介绍了如何在TensorFlow图中添加if条件?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有以下代码:

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shape=[1, 1], name = "condition")
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights")
b = tf.Variable(tf.zeros([label_option]), name = "bias")

if condition > 0:
    y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
    y = tf.nn.softmax(tf.matmul(x, W) - b)  

if 语句是否有效计算(我不这么认为)?如果没有,如何在TensorFlow计算图中添加 if 语句?

Would the if statement work in the calculation (I do not think so)? If not, how can I add an if statement into the TensorFlow calculation graph?

推荐答案

你是正确的 if 语句不起作用这里,因为条件是在图形构造时评估的,而大概是您希望条件依赖于在运行时提供给占位符的值。 (事实上​​,它总是需要第一个分支,因为 condition> 0 计算为 Tensor ,这是在Python中truthy。)

You're correct that the if statement doesn't work here, because the condition is evaluated at graph construction time, whereas presumably you want the condition to depend on the value fed to the placeholder at runtime. (In fact, it will always take the first branch, because condition > 0 evaluates to a Tensor, which is "truthy" in Python.)

为了支持条件控制流程,TensorFlow提供了 tf.cond() 运算符,它根据布尔条件计算两个分支之一。为了告诉你如何使用它,我将重写你的程序,以便 condition 是标量 tf.int32 值为简单起见:

To support conditional control flow, TensorFlow provides the tf.cond() operator, which evaluates one of two branches, depending on a boolean condition. To show you how to use it, I'll rewrite your program so that condition is a scalar tf.int32 value for simplicity:

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")

y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)

这篇关于如何在TensorFlow图中添加if条件?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆