在 TensorFlow 图中使用 if 条件 [英] Using if conditions inside a TensorFlow graph

查看:29
本文介绍了在 TensorFlow 图中使用 if 条件的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在 tensorflow CIFAR-10 教程cifar10_inputs.py 行174 据说你应该随机化操作 random_contrast 和 random_brightness 的顺序以获得更好的数据增强.

In tensorflow CIFAR-10 tutorial in cifar10_inputs.py line 174 it is said you should randomize the order of the operations random_contrast and random_brightness for better data augmentation.

为此,我想到的第一件事是从 0 和 1 之间的均匀分布中绘制一个随机变量:p_order.然后做:

To do so the first thing I think of is drawing a random variable from the uniform distribution between 0 and 1 : p_order. And do:

if p_order>0.5:
  distorted_image=tf.image.random_contrast(image)
  distorted_image=tf.image.random_brightness(distorted_image)
else:
  distorted_image=tf.image.random_brightness(image)
  distorted_image=tf.image.random_contrast(distorted_image)

然而,有两种可能的方式来获取 p_order:

However there are two possible options for getting p_order:

1) 使用 numpy 这让我不满意,因为我想要纯 TF 并且 TF 不鼓励其用户混合使用 numpy 和 tensorflow

1) Using numpy which disatisfies me as I wanted pure TF and that TF discourages its user to mix numpy and tensorflow

2) 使用 TF,但是因为 p_order 只能在 tf.Session() 中计算我真的不知道我是否应该这样做:

2) Using TF, however as p_order can only be evaluated in a tf.Session() I do not really know if I should do:

with tf.Session() as sess2:
  p_order_tensor=tf.random_uniform([1,],0.,1.)
  p_order=float(p_order_tensor.eval())

所有这些操作都在函数体内,并从另一个具有不同会话/图形的脚本中运行.或者我可以将另一个脚本中的图形作为参数传递给这个函数,但我很困惑.甚至像这样的 tensorflow 函数或推理这样的事实似乎以全局方式定义了图形,而没有明确地将其作为输出返回,这对我来说有点难以理解.

All those operations are inside the body of a function and are run from another script which has a different session/graph. Or I could pass the graph from the other script as an argument to this function but I am confused. Even the fact that tensorflow functions like this one or inference for example seem to define the graph in a global fashion without explicitly returning it as an output is a bit hard to understand for me.

推荐答案

您可以使用 tf.cond(pred, fn1, fn2, name=None) (参见文档).此函数允许您在 TensorFlow 图内使用 pred 的布尔值(无需调用 self.eval()sess.run(),因此不需要会话).

You can use tf.cond(pred, fn1, fn2, name=None) (see doc). This function allows you to use the boolean value of pred inside the TensorFlow graph (no need to call self.eval() or sess.run(), hence no need of a Session).

这是一个如何使用它的例子:

Here is an example of how to use it:

def fn1():
    distorted_image=tf.image.random_contrast(image)
    distorted_image=tf.image.random_brightness(distorted_image)
    return distorted_image
def fn2():
    distorted_image=tf.image.random_brightness(image)
    distorted_image=tf.image.random_contrast(distorted_image)
    return distorted_image

# Uniform variable in [0,1)
p_order = tf.random_uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
pred = tf.less(p_order, 0.5)

distorted_image = tf.cond(pred, fn1, fn2)

这篇关于在 TensorFlow 图中使用 if 条件的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆