张量中间维度的散点运算 [英] Scatter operation for middle dimension of a tensor
问题描述
我有一个 3d 张量,我需要在其中保留第二维中某些位置的向量,并将剩余的向量归零.这些位置被指定为一维数组.我认为最好的方法是将张量与二进制掩码相乘.
I have a 3d tensor where I need to preserve vectors at certain positions in the second dimension, and zero out the remaining vectors. The positions are specified as a 1d array. I'm thinking the best way to do this is to multiply the tensor with a binary mask.
这是一个简单的 Numpy 版本:
Here's a simple Numpy version:
A.shape: (b, n, m)
indices.shape: (b)
mask = np.zeros(A.shape)
for i in range(b):
mask[i][indices[i]] = 1
result = A*mask
因此,对于 A 中的每个 nxm 矩阵,我需要保留由索引指定的行,并将其余行归零.
So for each nxm matrix in A, I need to preserve rows specified by indices, and zero out the rest.
我正在尝试使用 tf.scatter_nd 在 TensorFlow 中执行此操作操作,但我无法弄清楚索引的正确形状:
I'm trying to do this in TensorFlow using tf.scatter_nd op, but I can't figure out the correct shape of indices:
shape = tf.constant([3,5,4])
A = tf.random_normal(shape)
indices = tf.constant([2,1,4]) #???
updates = tf.ones((3,4))
mask = tf.scatter_nd(indices, updates, shape)
result = A*mask
推荐答案
这是一种方法,创建一个掩码并使用 tf.where
:
Here's one way to do it, creating a mask and using tf.where
:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
shape = tf.constant([3,5,4])
A = tf.random_normal(shape)
array_shape = tf.shape(A)
indices = tf.constant([2,1,4])
non_zero_indices = tf.stack((tf.range(array_shape[0]), indices), axis=1)
should_keep_row = tf.scatter_nd(non_zero_indices, tf.ones_like(indices),
shape=[array_shape[0], array_shape[1]])
print("should_keep_row", should_keep_row)
masked = tf.where(tf.cast(tf.tile(should_keep_row[:, :, None],
[1, 1, array_shape[2]]), tf.bool),
A,
tf.zeros_like(A))
print("masked", masked)
打印:
should_keep_row tf.Tensor(
[[0 0 1 0 0]
[0 1 0 0 0]
[0 0 0 0 1]], shape=(3, 5), dtype=int32)
masked tf.Tensor(
[[[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0.02036316 -0.07163608 -3.16707373 1.31406844]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]
[[ 0. 0. 0. 0. ]
[-0.76696759 -0.28313264 0.87965059 -1.28844094]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]
[[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 1.03188455 0.44305769 0.71291149 1.59758031]]], shape=(3, 5, 4), dtype=float32)
(该示例使用了急切执行,但同样的操作也适用于会话中的图形执行)
(The example is using eager execution, but the same ops will work with graph execution in a Session)
这篇关于张量中间维度的散点运算的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!