在numpy数组上过滤具有特定条件的tensorflow数组 [英] filter tensorflow array with specific condition over numpy array

查看:350
本文介绍了在numpy数组上过滤具有特定条件的tensorflow数组的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个张量流数组名称tf-array和一个numpy数组名称np_array.我想在tf_array中找到关于np-array的特定行.

I have a tensorflow array names tf-array and a numpy array names np_array. I want to find specific rows in tf_array with regards to np-array.

    tf-array = tf.constant(
                [[9.968594,  8.655439,  0.,        0.       ],
                 [0.,        8.3356,    0.,        8.8974   ],
                 [0.,        0.,        6.103182,  7.330564 ],
                 [6.609862,  0.,        3.0614321, 0.       ],
                 [9.497023,  0.,        3.8914037, 0.       ],
                 [0.,        8.457685,  8.602337,  0.       ],
                 [0.,        0.,        5.826657,  8.283971 ]])

我也有一个np数组:

np_array = np.matrix(
 [[2, 5, 1],
  [1, 6, 4],
  [0, 0, 0],
  [2, 3, 6],
  [4, 2, 4]]

现在,我想将元素保留在tf-array中,其中n (here n is 2)的组合(它们的索引)的值是np-array.是什么意思?

Now I want to keep the elements in tf-array in which the combination of n (here n is 2) of them (index of them) is in the value of np-array. What does it mean?

例如,在tf-array的第一列中,具有值的索引为:(0,3,4). np-array中是否有包含以下两个索引的任意组合的行:(0,3), (0,4) or (3,4).实际上,没有这样的行.因此,该列中的所有元素都变为zero.

For example, in tf-array, in the first column, indexes which has value are: (0,3,4). Is there any row in np-array which contains any combination of these two indexes: (0,3), (0,4) or (3,4). Actually, there is no such row. So all the elements in that column became zero.

tf-array中第二列的索引为(0,1) (0,5) (1,5).如您所见,记录(1,5)在第一行的np-array中可用.这就是为什么我们将这些保留在tf-array中的原因.

Indexes for the second column in tf-array is (0,1) (0,5) (1,5). As you see the record (1,5) is available in the np-array in the first row. Thats why we keep those in the tf-array.

所以最终结果应该是这样的:

So the final result should be like this:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]]

我正在寻找一种非常有效的方法,因为我拥有大量数据.

I am looking for a very efficient approach as I have large number of data.

更新1

我可以使用下面的代码来得到它,该代码为True赋值,并且给false设置了零掩码:

I could get this with the below code which is giving True where there is value and the zero mask to false:

[[ True  True False False]
 [False  True False  True]
 [False False  True  True]
 [ True False  True False]
 [ True False  True False]
 [False  True  True False]
 [False False  True  True]]

with tf.Session() as sess:  
 where = tf.not_equal(tf-array, 0.0)
 print(sess.run(where))

但是我如何将Theese矩阵与np_array进行比较?

But how can I compare theese matrix with np_array?

提前谢谢!

推荐答案

这是 https://stackoverflow.com/a/56510832/7207392 ,并进行了必要的修改.为了简单起见,我将np.array用于所有数据.我不是tensortflow专家,所以如果翻译不完全是直截了当的,您将不得不问别人该怎么做.

Here is the solution from https://stackoverflow.com/a/56510832/7207392 with necessary modifications. For the sake of simplicity I use np.array for all data. I'm no tensortflow expert, so if translating is not entirely straight forward, you'll have to ask somebody else how to do it.

import numpy as np

def f(a1, a2, n):
    N,M = a1.shape
    a1p = np.concatenate([a1,np.zeros((1,a1.shape[1]),a1.dtype)], axis=0)
    a2 = np.sort(a2, axis=1)
    a2[:,1:][a2[:,1:]==a2[:,:-1]] = N
    y,x = np.where(np.count_nonzero(a1p[a2], axis=1) >= n)
    out = np.zeros_like(a1p)
    out[a2[y],x[:,None]] = a1p[a2[y],x[:,None]]
    return out[:-1]

a1 = np.array(
    [[9.968594,  8.655439,  0.,        0.       ],
     [0.,        8.3356,    0.,        8.8974   ],
     [0.,        0.,        6.103182,  7.330564 ],
     [6.609862,  0.,        3.0614321, 0.       ],
     [9.497023,  0.,        3.8914037, 0.       ],
     [0.,        8.457685,  8.602337,  0.       ],
     [0.,        0.,        5.826657,  8.283971 ]])

a2 = np.array(
 [[2, 5, 1],
  [1, 6, 4],
  [0, 0, 0],
  [2, 3, 6],
  [4, 2, 4]])

print(f(a1,a2,2))

输出:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]]

这篇关于在numpy数组上过滤具有特定条件的tensorflow数组的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆