如何在 TensorFlow 图中添加 if 条件? [英] How to add if condition in a TensorFlow graph?

查看:35
本文介绍了如何在 TensorFlow 图中添加 if 条件?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有以下代码:

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shape=[1, 1], name = "condition")
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights")
b = tf.Variable(tf.zeros([label_option]), name = "bias")

if condition > 0:
    y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
    y = tf.nn.softmax(tf.matmul(x, W) - b)  

if 语句是否在计算中起作用(我不这么认为)?如果没有,如何在 TensorFlow 计算图中添加 if 语句?

Would the if statement work in the calculation (I do not think so)? If not, how can I add an if statement into the TensorFlow calculation graph?

推荐答案

if 语句在这里不起作用是正确的,因为条件是在图构建时评估的,而大概您希望条件取决于在运行时提供给占位符的值.(实际上,它总是采用第一个分支,因为 condition > 0 计算结果为 Tensor,即 Python 中的真实".)

You're correct that the if statement doesn't work here, because the condition is evaluated at graph construction time, whereas presumably you want the condition to depend on the value fed to the placeholder at runtime. (In fact, it will always take the first branch, because condition > 0 evaluates to a Tensor, which is "truthy" in Python.)

为了支持条件控制流,TensorFlow 提供了tf.cond() 运算符,根据布尔条件计算两个分支之一.为了向您展示如何使用它,为了简单起见,我将重写您的程序,以便 condition 是一个标量 tf.int32 值:

To support conditional control flow, TensorFlow provides the tf.cond() operator, which evaluates one of two branches, depending on a boolean condition. To show you how to use it, I'll rewrite your program so that condition is a scalar tf.int32 value for simplicity:

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")

y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)

这篇关于如何在 TensorFlow 图中添加 if 条件?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆