张量中间维度的散点运算 [英] Scatter operation for middle dimension of a tensor

查看:28
本文介绍了张量中间维度的散点运算的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个 3d 张量,我需要在其中保留第二维中某些位置的向量,并将剩余的向量归零.这些位置被指定为一维数组.我认为最好的方法是将张量与二进制掩码相乘.

I have a 3d tensor where I need to preserve vectors at certain positions in the second dimension, and zero out the remaining vectors. The positions are specified as a 1d array. I'm thinking the best way to do this is to multiply the tensor with a binary mask.

这是一个简单的 Numpy 版本:

Here's a simple Numpy version:

A.shape: (b, n, m) 
indices.shape: (b)

mask = np.zeros(A.shape)
for i in range(b):
  mask[i][indices[i]] = 1
result = A*mask

因此,对于 A 中的每个 nxm 矩阵,我需要保留由索引指定的行,并将其余行归零.

So for each nxm matrix in A, I need to preserve rows specified by indices, and zero out the rest.

我正在尝试使用 tf.scatter_nd 在 TensorFlow 中执行此操作操作,但我无法弄清楚索引的正确形状:

I'm trying to do this in TensorFlow using tf.scatter_nd op, but I can't figure out the correct shape of indices:

shape = tf.constant([3,5,4])
A = tf.random_normal(shape)       
indices = tf.constant([2,1,4])   #???   
updates = tf.ones((3,4))           
mask = tf.scatter_nd(indices, updates, shape) 
result = A*mask

推荐答案

这是一种方法,创建一个掩码并使用 tf.where:

Here's one way to do it, creating a mask and using tf.where:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()

shape = tf.constant([3,5,4])
A = tf.random_normal(shape)

array_shape = tf.shape(A)
indices = tf.constant([2,1,4])
non_zero_indices = tf.stack((tf.range(array_shape[0]), indices), axis=1)
should_keep_row = tf.scatter_nd(non_zero_indices, tf.ones_like(indices),
                                shape=[array_shape[0], array_shape[1]])
print("should_keep_row", should_keep_row)
masked = tf.where(tf.cast(tf.tile(should_keep_row[:, :, None],
                                  [1, 1, array_shape[2]]), tf.bool),
                   A,
                   tf.zeros_like(A))
print("masked", masked)

打印:

should_keep_row tf.Tensor(
[[0 0 1 0 0]
 [0 1 0 0 0]
 [0 0 0 0 1]], shape=(3, 5), dtype=int32)
masked tf.Tensor(
[[[ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.02036316 -0.07163608 -3.16707373  1.31406844]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]

 [[ 0.          0.          0.          0.        ]
  [-0.76696759 -0.28313264  0.87965059 -1.28844094]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]

 [[ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 1.03188455  0.44305769  0.71291149  1.59758031]]], shape=(3, 5, 4), dtype=float32)

(该示例使用了急切执行,但同样的操作也适用于会话中的图形执行)

(The example is using eager execution, but the same ops will work with graph execution in a Session)

这篇关于张量中间维度的散点运算的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆