How to use tf.cond for batch processing

Problem description

I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Let's say I have two tensors, x and y. Each tensor is a batch of 0/1 values, and I want to use their comparison x < y as the source for the pred argument of tf.cond:

pred: A scalar determining whether to return the result of fn1 or fn2.

But if I am working with batches, it looks like I need to iterate over the source tensor inside the graph, slice out every item in the batch, and apply tf.cond to each item. That looks suspicious to me. Why does tf.cond accept only a scalar and not a batch? What is the right way to use it with a batch?
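
For reference, the per-item loop described above can be written with tf.map_fn, which calls tf.cond once per batch element. This is a minimal sketch assuming TensorFlow 2.x with eager execution (fn_output_signature needs TF 2.3 or later); the sample x and y values and the two branch expressions (x + y and x - y) are hypothetical stand-ins for fn1 and fn2.

```python
import tensorflow as tf  # assumes TensorFlow 2.3+ (eager execution)

# Hypothetical batches of 0/1 values, as in the question.
x = tf.constant([0, 1, 0, 1], dtype=tf.float32)
y = tf.constant([1, 1, 0, 0], dtype=tf.float32)

def per_item(elems):
    xi, yi = elems
    # tf.cond needs a scalar predicate, so it sees one item at a time.
    # The two lambdas are hypothetical stand-ins for fn1 and fn2.
    return tf.cond(xi < yi, lambda: xi + yi, lambda: xi - yi)

# tf.map_fn loops over the leading (batch) dimension of both tensors.
result = tf.map_fn(per_item, (x, y), fn_output_signature=tf.float32)
print(result.numpy())
```

This works, but it runs the branches item by item rather than as single vectorized ops over the batch.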

Solution

tf.where sounds like what you want: a vectorized selection between Tensors.
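
A minimal sketch of that vectorized selection, assuming TensorFlow 2.x; the sample tensors and the branch expressions (x + y and x - y) are hypothetical. Both candidate tensors are computed for the whole batch, and tf.where takes each element from one or the other according to x < y.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Hypothetical batches of 0/1 values.
x = tf.constant([0, 1, 0, 1], dtype=tf.float32)
y = tf.constant([1, 1, 0, 0], dtype=tf.float32)

# Both candidates are computed for every element of the batch...
if_true = x + y   # hypothetical stand-in for fn1
if_false = x - y  # hypothetical stand-in for fn2

# ...then tf.where picks element by element: if_true where x < y,
# if_false everywhere else.
result = tf.where(x < y, if_true, if_false)
print(result.numpy())
```

Unlike tf.cond, both branches are always evaluated here, which is fine when they are cheap elementwise expressions.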

tf.cond is a control flow modifier: it determines which ops are executed. Its scalar pred selects a single branch for each execution, so it's difficult to think of useful semantics for a batch of predicates.
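
If one decision for the whole batch is what is actually wanted, one option is to reduce the batch of booleans to a scalar first, for example with tf.reduce_any or tf.reduce_all, and pass that to tf.cond. A sketch assuming TensorFlow 2.x, with hypothetical sample tensors and branches:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Hypothetical batches of 0/1 values.
x = tf.constant([0, 1, 0, 1], dtype=tf.float32)
y = tf.constant([1, 1, 0, 0], dtype=tf.float32)

# tf.cond picks one branch for the entire call, so the per-item
# comparison x < y is collapsed to a single scalar boolean first.
any_less = tf.reduce_any(x < y)  # True if at least one item has x < y
all_less = tf.reduce_all(x < y)  # True only if every item has x < y

# Hypothetical branches; each one operates on the whole batch.
out_any = tf.cond(any_less, lambda: x + y, lambda: x - y)
out_all = tf.cond(all_less, lambda: x + y, lambda: x - y)
print(out_any.numpy(), out_all.numpy())
```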

We can also combine the two ideas: an operation that slices the input according to a condition and passes those slices to two branches.

import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
  """Split `full_input` between `true_branch` and `false_branch` on `condition`.

  Args:
    condition: A boolean Tensor with shape [B_1, ..., B_N].
    full_input: A Tensor or nested tuple of Tensors of any dtype, each with
      shape [B_1, ..., B_N, ...], to be split between `true_branch` and
      `false_branch` based on `condition`.
    true_branch: A function taking a single argument, that argument having the
      same structure as `full_input`, with the batch dimensions flattened into
      a single leading dimension (gather_nd collapses them). Receives
      slices of `full_input` corresponding to the True entries of
      `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
      dimensions matching its inputs.
    false_branch: Like `true_branch`, but receives inputs corresponding to the
      false elements of `condition`. Returns a Tensor or nested tuple of Tensors
      (with the same structure as the return value of `true_branch`), but with
      batch dimensions matching its inputs.
  Returns:
    Interleaved outputs from `true_branch` and `false_branch`, each Tensor
    having shape [B_1, ..., B_N, ...].
  """
  full_input_flat = nest.flatten(full_input)
  true_indices = tf.where(condition)
  false_indices = tf.where(tf.logical_not(condition))
  true_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                     for input_tensor in full_input_flat])
  false_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                     for input_tensor in full_input_flat])
  true_outputs = true_branch(true_branch_inputs)
  false_outputs = false_branch(false_branch_inputs)
  nest.assert_same_structure(true_outputs, false_outputs)
  def scatter_outputs(true_output, false_output):
    batch_shape = tf.shape(condition)
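    # true_output has a single leading dimension (the gathered items), so
    # tf.rank(batch_shape), which is always 1, skips exactly that dimension.
    scattered_shape = tf.concat(
        [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
        0)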
    true_scatter = tf.scatter_nd(
        indices=tf.cast(true_indices, tf.int32),
        updates=true_output,
        shape=scattered_shape)
    false_scatter = tf.scatter_nd(
        indices=tf.cast(false_indices, tf.int32),
        updates=false_output,
        shape=scattered_shape)
    return true_scatter + false_scatter
  result = nest.pack_sequence_as(
      structure=true_outputs,
      flat_sequence=[
          scatter_outputs(true_single_output, false_single_output)
          for true_single_output, false_single_output
          in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
  return result
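
For the common case of a single batch dimension, the same slice, branch and reassemble idea can also be sketched with tf.dynamic_partition and tf.dynamic_stitch instead of gather_nd and scatter_nd. This is a swapped-in technique, not slicing_where itself; partition_apply is a hypothetical helper name, and the sketch assumes TensorFlow 2.x, a 1-D boolean condition whose length equals the first dimension of the input, and branches that return one output row per input row.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def partition_apply(condition, full_input, true_branch, false_branch):
    """Hypothetical helper: route rows of full_input to one of two branches.

    Assumes condition is a 1-D boolean tensor with one entry per row.
    """
    partitions = tf.cast(condition, tf.int32)  # 0 = false rows, 1 = true rows
    # Split the data rows into [false rows, true rows].
    false_rows, true_rows = tf.dynamic_partition(full_input, partitions, 2)
    # Split the row positions the same way, to remember where each row came from.
    positions = tf.range(tf.shape(full_input)[0])
    false_pos, true_pos = tf.dynamic_partition(positions, partitions, 2)
    # Each branch sees only its own rows; dynamic_stitch writes the outputs
    # back into the original row order.
    return tf.dynamic_stitch([false_pos, true_pos],
                             [false_branch(false_rows), true_branch(true_rows)])

result = partition_apply(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)
print(result.numpy())
```

slicing_where handles N batch dimensions and nested inputs; this variant trades that generality for fewer ops.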

Some examples:

vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)

cross_range = (tf.range(10, dtype=tf.float32)[:, None]
               * tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0),
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)

with tf.Session():  # TensorFlow 1.x graph-mode API
  print(vector_test.eval())
  print(matrix_test.eval())

Prints:

[ 0.2         1.10000002  2.20000005  3.0999999   4.19999981  5.0999999
  6.19999981  7.0999999   8.19999981  9.10000038]
[[  0.           0.           0.           0.           0.           0.
    0.           0.           0.           0.        ]
 [  0.1          1.10000002   2.0999999    3.0999999    4.0999999
    5.0999999    6.0999999    7.0999999    8.10000038   9.10000038]
 [  0.1          2.0999999    4.0999999    6.0999999    8.10000038
   10.10000038  12.10000038  14.10000038  16.10000038  18.10000038]
 [  0.          -3.          -6.          -9.         -12.         -15.
  -18.         -21.         -24.         -27.        ]
 [  0.1          4.0999999    8.10000038  12.10000038  16.10000038
   20.10000038  24.10000038  28.10000038  32.09999847  36.09999847]
 [  0.1          5.0999999   10.10000038  15.10000038  20.10000038
   25.10000038  30.10000038  35.09999847  40.09999847  45.09999847]
 [  0.          -6.         -12.         -18.         -24.         -30.
  -36.         -42.         -48.         -54.        ]
 [  0.1          7.0999999   14.10000038  21.10000038  28.10000038
   35.09999847  42.09999847  49.09999847  56.09999847  63.09999847]
 [  0.1          8.10000038  16.10000038  24.10000038  32.09999847
   40.09999847  48.09999847  56.09999847  64.09999847  72.09999847]
 [  0.          -9.         -18.         -27.         -36.         -45.
  -54.         -63.         -72.         -81.        ]]
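
Because both example branches are elementwise, the same numbers can also be obtained with plain tf.where by computing both branches everywhere; slicing_where pays off when the branches are expensive or must see only the selected rows. A sketch assuming TensorFlow 2.x, where tf.where broadcasts its arguments:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (tf.where broadcasts)

x = tf.range(10, dtype=tf.float32)

# Vector case: even positions get +0.2, odd positions get +0.1.
vector_where = tf.where(tf.equal(tf.range(10) % 2, 0), 0.2 + x, 0.1 + x)

# Matrix case: the per-row condition gets a trailing axis of size 1 so
# that it broadcasts across the columns of cross_range.
cross_range = x[:, None] * x[None, :]
row_is_true = tf.equal(tf.range(10) % 3, 0)[:, None]
matrix_where = tf.where(row_is_true, -cross_range, cross_range + 0.1)

print(vector_where.numpy())
print(matrix_where.numpy())
```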
