faster alternative to numpy.where?


Question


I have a 3d array filled with integers from 0 to N. I need a list of the indices corresponding to where the array equals 1, 2, 3, ..., N. I can do it with np.where as follows:

import numpy as np

N = 300
shape = (1000, 1000, 10)
data = np.random.randint(0, N + 1, shape)
indx = [np.where(data == i_id) for i_id in range(1, data.max() + 1)]

but this is quite slow. According to the question "fast python numpy where functionality?" it should be possible to speed up the index search quite a lot, but I haven't been able to transfer the methods proposed there to my problem of getting the actual indices. What would be the best way to speed up the above code?

As an add-on: I want to store the indices later, for which it makes sense to use np.ravel_multi_index to reduce the size from saving 3 indices to only 1, i.e. using:

indx = [np.ravel_multi_index(np.where(data == i_id), data.shape) for i_id in range(1, data.max()+1)]

which is closer to e.g. Matlab's find function. Can this be directly incorporated in a solution that doesn't use np.where?
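For instance, a toy round trip through np.ravel_multi_index and np.unravel_index (hypothetical shape and indices, just to illustrate the 3-indices-to-1 reduction):

```python
import numpy as np

# Two points in a hypothetical (4, 5, 6) array, given as three index arrays.
shape = (4, 5, 6)
ijk = (np.array([0, 3]), np.array([1, 4]), np.array([2, 5]))

# Collapse the three index arrays into one array of flat (raveled) indices.
flat = np.ravel_multi_index(ijk, shape)
print(flat.tolist())  # [8, 119]

# The flat indices can be expanded back to the original 3-tuple form.
back = np.unravel_index(flat, shape)
```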

Solution

I think that a standard vectorized approach to this problem would end up being very memory intensive – for int64 data, it would require O(8 * N * data.size) bytes, or ~22 gigs of memory for the example you gave above. I'm assuming that is not an option.
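That figure is easy to sanity-check with a quick back-of-envelope calculation for the example's N = 300 and shape (1000, 1000, 10):

```python
# 8 bytes per int64 index, one dense N x data.size table of indices.
N = 300
size = 1000 * 1000 * 10
gib = 8 * N * size / 2**30
print(round(gib, 1))  # 22.4
```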

You might make some progress by using a sparse matrix to store the locations of the unique values. For example:

import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    # Row index = value, column index = flat position, stored entry = the
    # flat position itself, so row i collects the flat indices where
    # data == i.
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

def get_indices_sparse(data):
    # Read each row's stored flat indices back out and convert them to
    # multi-dimensional index arrays.
    M = compute_M(data)
    return [np.unravel_index(row.data, data.shape) for row in M]

This takes advantage of fast code within the sparse matrix constructor to organize the data in a useful way, constructing a sparse matrix where row i contains just the indices where the flattened data equals i.
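For instance, on a tiny hypothetical 1-D array, row i of M ends up holding exactly the flat positions where data equals i (compute_M is repeated here so the snippet runs on its own):

```python
import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

# The value 2 occurs at flat positions 2 and 3.
data = np.array([1, 0, 2, 2, 1])
M = compute_M(data)
print(M[2].data.tolist())  # [2, 3]
```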

To test it out, I'll also define a function that does your straightforward method:

def get_indices_simple(data):
    return [np.where(data == i) for i in range(0, data.max() + 1)]

The two functions give the same results for the same input:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_sparse(data_small)))
# True

And the sparse method is an order of magnitude faster than the simple method for your dataset:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.1 s, sys: 638 ms, total: 14.7 s
# Wall time: 14.8 s

%time ind = get_indices_sparse(data)
# CPU times: user 881 ms, sys: 301 ms, total: 1.18 s
# Wall time: 1.18 s

%time M = compute_M(data)
# CPU times: user 216 ms, sys: 148 ms, total: 365 ms
# Wall time: 363 ms

The other benefit of the sparse method is that the matrix M ends up being a very compact and efficient way to store all the relevant information for later use, as mentioned in the add-on part of your question. Hope that's useful!
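A rough sketch of that later use, on a hypothetical 2x2 array (the save/load step is only hinted at in a comment; only M and the original shape need to be kept around):

```python
import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

data = np.array([[1, 2],
                 [3, 3]])
M = compute_M(data)  # in practice: persist with e.g. scipy.sparse.save_npz

# Later: look up one label without rescanning the full data array.
i = 3
flat = M[i].data                          # flat (raveled) indices of value i
idx = np.unravel_index(flat, data.shape)  # back to (row, col) index arrays
print(flat.tolist())              # [2, 3]
print([a.tolist() for a in idx])  # [[1, 1], [0, 1]]
```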


Edit: I realized there was a bug in the initial version: it failed if any value in the range didn't appear in the data. That's now fixed above.
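A minimal sketch of the edge case that fix covers, using a small hypothetical array in which one in-range value never occurs:

```python
import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

def get_indices_sparse(data):
    M = compute_M(data)
    return [np.unravel_index(row.data, data.shape) for row in M]

# The value 1 lies in range(0, data.max() + 1) but never appears.
data = np.array([0, 2, 2])
ind = get_indices_sparse(data)
print(ind[1][0].size)  # 0 -- empty index arrays, no error
```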

