过滤张量中的非零值 [英] Filter out non-zero values in a tensor
问题描述
假设我有一个数组:input = np.array([[1,0,3,5,0,8,6]])
,我想过滤出[1,3,5,8,6]
.
Suppose I have an array: input = np.array([[1,0,3,5,0,8,6]])
, and I want to filter out [1,3,5,8,6]
.
我知道您可以在条件下使用tf.where
,但是返回的值中仍然有0.以下代码段的输出为[[[1 0 3 5 0 8 6]]]
.我也不明白为什么tf.where
同时需要x
和y
.
I know that you can use tf.where
with a condition but the returned value still has 0's in it. Output of the following snippet is [[[1 0 3 5 0 8 6]]]
. I also don't understand why tf.where
needs both x
and y
.
反正我可以摆脱结果张量中的0吗?
Is there anyway I can get rid of the 0's in the resulting tensor?
import numpy as np
import tensorflow as tf
input = np.array([[1,0,3,5,0,8,6]])
X = tf.placeholder(tf.int32,[None,7])
zeros = tf.zeros_like(X)
index = tf.not_equal(X,zeros)
loc = tf.where(index,x=X,y=X)
with tf.Session() as sess:
out = sess.run([loc],feed_dict={X:input})
print np.array(out)
推荐答案
首先创建一个布尔掩码,以标识您的条件为真;然后将遮罩应用于张量,如下所示.您可以使用tf.where进行索引-但是它使用x& y返回一个张量,其张数与输入的等级相同,因此,如果不做进一步的工作,您可以达到的最佳效果将类似于[[[1 -1 3 5- 1 8 6]]]将-1更改为您以后识别的内容.仅使用where(不带x& y)将为您提供条件为真的所有值的索引,因此如果您愿意,可以使用索引来创建解决方案.为了最清楚起见,我的建议如下.
First create a boolean mask to identify where your condition is true; then apply the mask to your tensor, as shown below. You can if you want use tf.where to index - however it returns a tensor using x&y with the same rank as the input so without further work the best you could achieve would be something like [[[1 -1 3 5 -1 8 6]]] changing -1 with something that you would identify to remove later. Just using where (without x&y) will give you the index of all values where your condition is true so a solution can be created using indexes if that is what you prefer. My recommendation is below for the most clarity.
import numpy as np
import tensorflow as tf
input = np.array([[1,0,3,5,0,8,6]])
X = tf.placeholder(tf.int32,[None,7])
zeros = tf.cast(tf.zeros_like(X),dtype=tf.bool)
ones = tf.cast(tf.ones_like(X),dtype=tf.bool)
loc = tf.where(input!=0,ones,zeros)
result=tf.boolean_mask(input,loc)
with tf.Session() as sess:
out = sess.run([result],feed_dict={X:input})
print (np.array(out))
这篇关于过滤张量中的非零值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!