How to use tf.cond for batch processing
Problem description
I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Let's say I have two tensors, x and y. Each tensor is a batch of 0/1 values, and I want to use the elementwise comparison x < y as the pred argument of tf.cond:
pred: A scalar determining whether to return the result of fn1 or fn2.
But if I am working with batches, it looks like I need to iterate over the source tensor inside the graph, slice out every item in the batch, and apply tf.cond to each item. That looks suspicious to me. Why does tf.cond accept only a scalar and not a batch? Can you advise on the right way to use it with a batch?
tf.where sounds like what you want: a vectorized selection between Tensors.
tf.cond is a control flow modifier: it determines which ops are executed, so it is difficult to think of useful batch semantics.
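As a minimal sketch of the tf.where suggestion, the snippet below selects per element between two precomputed tensors using x < y as the condition. It assumes TensorFlow 2.x with eager execution (the rest of this answer uses the older Session API), and the names if_true and if_false are hypothetical stand-ins for whatever fn1 and fn2 would compute. Note that, unlike tf.cond, both candidate tensors are computed for the whole batch.

```python
import tensorflow as tf

# Two batches of 0/1 values, as described in the question.
x = tf.constant([0, 1, 0, 1])
y = tf.constant([1, 1, 0, 0])

# Hypothetical per-element results of the two "branches".
# Both are evaluated for every element; tf.where only selects between them.
if_true = tf.constant([10.0, 20.0, 30.0, 40.0])
if_false = -if_true

pred = x < y  # boolean tensor of shape [4], one entry per batch item
selected = tf.where(pred, if_true, if_false)
print(selected.numpy())
```

Only the first element satisfies x < y here, so only that element is taken from if_true.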
We can also put together a mixture of these operations: an operation which slices based on a condition and passes those slices to two branches.
import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
  """Split `full_input` between `true_branch` and `false_branch` on `condition`.

  Args:
    condition: A boolean Tensor with shape [B_1, ..., B_N].
    full_input: A Tensor or nested tuple of Tensors of any dtype, each with
      shape [B_1, ..., B_N, ...], to be split between `true_branch` and
      `false_branch` based on `condition`.
    true_branch: A function taking a single argument, that argument having the
      same structure and number of batch dimensions as `full_input`. Receives
      slices of `full_input` corresponding to the True entries of
      `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
      dimensions matching its inputs.
    false_branch: Like `true_branch`, but receives inputs corresponding to the
      false elements of `condition`. Returns a Tensor or nested tuple of Tensors
      (with the same structure as the return value of `true_branch`), but with
      batch dimensions matching its inputs.
  Returns:
    Interleaved outputs from `true_branch` and `false_branch`, each Tensor
    having shape [B_1, ..., B_N, ...].
  """
  full_input_flat = nest.flatten(full_input)
  true_indices = tf.where(condition)
  false_indices = tf.where(tf.logical_not(condition))
  true_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                     for input_tensor in full_input_flat])
  false_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                     for input_tensor in full_input_flat])
  true_outputs = true_branch(true_branch_inputs)
  false_outputs = false_branch(false_branch_inputs)
  nest.assert_same_structure(true_outputs, false_outputs)
  def scatter_outputs(true_output, false_output):
    batch_shape = tf.shape(condition)
    scattered_shape = tf.concat(
        [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
        0)
    true_scatter = tf.scatter_nd(
        indices=tf.cast(true_indices, tf.int32),
        updates=true_output,
        shape=scattered_shape)
    false_scatter = tf.scatter_nd(
        indices=tf.cast(false_indices, tf.int32),
        updates=false_output,
        shape=scattered_shape)
    return true_scatter + false_scatter
  result = nest.pack_sequence_as(
      structure=true_outputs,
      flat_sequence=[
          scatter_outputs(true_single_output, false_single_output)
          for true_single_output, false_single_output
          in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
  return result
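To make the gather, apply, scatter pattern of this function easier to follow, here is a rough NumPy analogue for a single array input. It is only an illustration, not part of the TensorFlow code: slicing_where_np is a hypothetical helper, it does not handle nested tuples, and it writes each branch's output directly into the result instead of summing two scattered tensors.

```python
import numpy as np

def slicing_where_np(condition, full_input, true_branch, false_branch):
    """Hypothetical NumPy analogue of slicing_where for one array input."""
    condition = np.asarray(condition, dtype=bool)
    # Gather: boolean indexing keeps only the selected batch entries,
    # giving arrays of shape [num_selected, ...].
    true_out = np.asarray(true_branch(full_input[condition]))
    false_out = np.asarray(false_branch(full_input[~condition]))
    # Scatter: allocate the full batch shape plus any trailing output
    # dimensions, then write each branch's output back to its positions.
    result = np.empty(condition.shape + true_out.shape[1:],
                      dtype=np.result_type(true_out, false_out))
    result[condition] = true_out
    result[~condition] = false_out
    return result

x = np.arange(6, dtype=np.float32)
out = slicing_where_np(x % 2 == 0, x, lambda t: t + 0.5, lambda t: -t)
print(out)
```

Each branch only ever sees the entries routed to it, which is the difference from a plain elementwise selection where both branches are computed for the whole batch.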
Some examples:
vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)
cross_range = (tf.range(10, dtype=tf.float32)[:, None]
               * tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0),
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)
with tf.Session():
  print(vector_test.eval())
  print(matrix_test.eval())
Prints:
[ 0.2 1.10000002 2.20000005 3.0999999 4.19999981 5.0999999
6.19999981 7.0999999 8.19999981 9.10000038]
[[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[ 0.1 1.10000002 2.0999999 3.0999999 4.0999999
5.0999999 6.0999999 7.0999999 8.10000038 9.10000038]
[ 0.1 2.0999999 4.0999999 6.0999999 8.10000038
10.10000038 12.10000038 14.10000038 16.10000038 18.10000038]
[ 0. -3. -6. -9. -12. -15.
-18. -21. -24. -27. ]
[ 0.1 4.0999999 8.10000038 12.10000038 16.10000038
20.10000038 24.10000038 28.10000038 32.09999847 36.09999847]
[ 0.1 5.0999999 10.10000038 15.10000038 20.10000038
25.10000038 30.10000038 35.09999847 40.09999847 45.09999847]
[ 0. -6. -12. -18. -24. -30.
-36. -42. -48. -54. ]
[ 0.1 7.0999999 14.10000038 21.10000038 28.10000038
35.09999847 42.09999847 49.09999847 56.09999847 63.09999847]
[ 0.1 8.10000038 16.10000038 24.10000038 32.09999847
40.09999847 48.09999847 56.09999847 64.09999847 72.09999847]
[ 0. -9. -18. -27. -36. -45.
-54. -63. -72. -81. ]]