np.add.at与数组建立索引 [英] np.add.at indexing with array

查看:161
本文介绍了np.add.at与数组建立索引的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用cs231n,并且在理解此索引的工作方式方面遇到了困难.鉴于

I'm working on cs231n and I'm having a difficult time understanding how this indexing works. Given that

x = [[0,4,1], [3,2,4]]
dW = np.zeros(5,6)
dout = [[[  1.19034710e-01  -4.65005990e-01   8.93743168e-01  -9.78047129e-01
            -8.88672957e-01  -4.66605091e-01]
         [ -1.38617461e-03  -2.64569728e-01  -3.83712733e-01  -2.61360826e-01
            8.07072009e-01  -5.47607277e-01]
         [ -3.97087458e-01  -4.25187949e-02   2.57931759e-01   7.49565950e-01
           1.37707667e+00   1.77392240e+00]]

       [[ -1.20692745e+00  -8.28111550e-01   6.53041092e-01  -2.31247762e+00
         -1.72370321e+00   2.44308033e+00]
        [ -1.45191870e+00  -3.49328154e-01   6.15445782e-01  -2.84190582e-01
           4.85997687e-02   4.81590106e-01]
        [ -1.14828583e+00  -9.69055406e-01  -1.00773809e+00   3.63553835e-01
          -1.28078363e+00  -2.54448436e+00]]]

他们所做的操作是

np.add.at(dW, x, dout)

x是一个二维数组.索引在这里如何工作?我浏览了np.ufunc.at文档,但是他们有带有1d数组和常量的简单示例:

x is a two dimensional array. How does indexing work here? I went through np.ufunc.at documentation but they have simple examples with 1d array and constant:

np.add.at(a, [0, 1, 2, 2], 1)

推荐答案

In [226]: x = [[0,4,1], [3,2,4]]
     ...: dW = np.zeros((5,6),int)

In [227]: np.add.at(dW,x,1)
In [228]: dW
Out[228]: 
array([[0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0]])

使用此x时,没有任何重复的条目,因此add.at与使用+=索引相同.等效地,我们可以通过以下方式读取更改的值:

With this x there aren't any duplicate entries, so add.at is the same as using += indexing. Equivalently we can read the changed values with:

In [229]: dW[x[0], x[1]]
Out[229]: array([1, 1, 1])

两种索引的工作方式相同,包括广播:

The indices work the same either way, including broadcasting:

In [234]: dW[...]=0
In [235]: np.add.at(dW,[[[1],[2]],[2,4,4]],1)
In [236]: dW
Out[236]: 
array([[0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 2, 0],
       [0, 0, 1, 0, 2, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]])

可能的值

相对于索引,值必须为broadcastable:

In [112]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)))
...
In [114]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)).ravel())
...
ValueError: array is not broadcastable to correct shape
In [115]: np.add.at(dW,[[[1],[2]],[2,4,4]],[1,2,3])

In [117]: np.add.at(dW,[[[1],[2]],[2,4,4]],[[1],[2]])

In [118]: dW
Out[118]: 
array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  3,  0,  9,  0],
       [ 0,  0,  4,  0, 11,  0],
       [ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0]])

在这种情况下,索引定义了(2,3)形状,因此(2,3),(3,),(2,1)和标量值起作用. (6,)不会.

In this case the indices define a (2,3) shape, so (2,3),(3,), (2,1), and scalar values work. (6,) does not.

在这种情况下,add.at将(2,3)数组映射到dW的(2,2)子数组上.

In this case, add.at is mapping a (2,3) array onto a (2,2) subarray of dW.

这篇关于np.add.at与数组建立索引的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆