过滤张量中的非零值 [英] Filter out non-zero values in a tensor

查看:150
本文介绍了过滤张量中的非零值的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有一个数组:input = np.array([[1,0,3,5,0,8,6]]),我想过滤出[1,3,5,8,6].

Suppose I have an array: input = np.array([[1,0,3,5,0,8,6]]), and I want to filter out [1,3,5,8,6].

我知道您可以在条件下使用tf.where,但是返回的值中仍然有0.以下代码段的输出为[[[1 0 3 5 0 8 6]]].我也不明白为什么tf.where同时需要xy.

I know that you can use tf.where with a condition but the returned value still has 0's in it. Output of the following snippet is [[[1 0 3 5 0 8 6]]]. I also don't understand why tf.where needs both x and y.

反正我可以摆脱结果张量中的0吗?

Is there anyway I can get rid of the 0's in the resulting tensor?

import numpy as np
import tensorflow as tf

input = np.array([[1,0,3,5,0,8,6]])

X = tf.placeholder(tf.int32,[None,7])

zeros = tf.zeros_like(X)
index = tf.not_equal(X,zeros)
loc = tf.where(index,x=X,y=X)

with tf.Session() as sess:
    out = sess.run([loc],feed_dict={X:input})
    print np.array(out)

推荐答案

首先创建一个布尔掩码,以标识您的条件为真;然后将遮罩应用于张量,如下所示.您可以使用tf.where进行索引-但是它使用x& y返回一个张量,其张数与输入的等级相同,因此,如果不做进一步的工作,您可以达到的最佳效果将类似于[[[1 -1 3 5- 1 8 6]]]将-1更改为您以后识别的内容.仅使用where(不带x& y)将为您提供条件为真的所有值的索引,因此如果您愿意,可以使用索引来创建解决方案.为了最清楚起见,我的建议如下.

First create a boolean mask to identify where your condition is true; then apply the mask to your tensor, as shown below. You can if you want use tf.where to index - however it returns a tensor using x&y with the same rank as the input so without further work the best you could achieve would be something like [[[1 -1 3 5 -1 8 6]]] changing -1 with something that you would identify to remove later. Just using where (without x&y) will give you the index of all values where your condition is true so a solution can be created using indexes if that is what you prefer. My recommendation is below for the most clarity.

import numpy as np
import tensorflow as tf
input = np.array([[1,0,3,5,0,8,6]])
X = tf.placeholder(tf.int32,[None,7])
zeros = tf.cast(tf.zeros_like(X),dtype=tf.bool)
ones = tf.cast(tf.ones_like(X),dtype=tf.bool)
loc = tf.where(input!=0,ones,zeros)
result=tf.boolean_mask(input,loc)
with tf.Session() as sess:
 out = sess.run([result],feed_dict={X:input})
 print (np.array(out))

这篇关于过滤张量中的非零值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆