从Numpy数组的索引中采样的有效方法? [英] Efficient way of sampling from indices of a Numpy array?

查看:407
本文介绍了从Numpy数组的索引中采样的有效方法?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想从2D Numpy数组的索引中进行采样,考虑到每个索引都由该数组内部的数字加权.我知道的方式是使用numpy.random.choice,但是它不会返回索引,而是返回数字本身.有什么有效的方法吗?

I'd like to sample from indices of a 2D Numpy array, considering that each index is weighted by the number inside of that array. The way I know it is with numpy.random.choice however that does not return the index but the number itself. Is there any efficient way of doing so?

这是我的代码:

import numpy as np
A=np.arange(1,10).reshape(3,3)
A_flat=A.flatten()
d=np.random.choice(A_flat,size=10,p=A_flat/float(np.sum(A_flat)))
print d

推荐答案

您可以执行以下操作:

import numpy as np

def wc(weights):
    cs = np.cumsum(weights)
    idx = cs.searchsorted(np.random.random() * cs[-1], 'right')
    return np.unravel_index(idx, weights.shape)

请注意,求和是最慢的部分,因此,如果需要对同一数组重复执行此操作,建议您提前计算并重新使用求和.

Notice that the cumsum is the slowest part of this, so if you need to do this repeatidly for the same array I'd suggest computing the cumsum ahead of time and reusing it.

这篇关于从Numpy数组的索引中采样的有效方法?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆