Python:用于多维数组的numpy.dot/numpy.tensordot [英] Python: numpy.dot / numpy.tensordot for multidimensional arrays

查看:78
本文介绍了Python:用于多维数组的numpy.dot/numpy.tensordot的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在优化我的反向传播算法的实现,以训练神经网络.我正在研究的方面之一是对数据点集(输入/输出向量)执行矩阵运算,这是由numpy库优化的批处理过程,而不是遍历每个数据点.

I'm optimising my implementation of the back-propagation algorithm to train a neural network. One of the aspects I'm working on is performing the matrix operations on the set of datapoints (input/output vector) as a batch process optimised by the numpy library instead of looping through every datapoint.

在原始算法中,我执行了以下操作:

In my original algorithm I did the following:

for datapoint in datapoints:
  A = ... (created out of datapoint info)
  B = ... (created out of datapoint info)

  C = np.dot(A,B.transpose())
____________________

A: (7,1) numpy array
B: (6,1) numpy array
C: (7,6) numpy array

然后,我将所述矩阵扩展为张量,其中第一个形状索引将引用数据集.如果我有3个数据集(为简单起见),则矩阵如下所示:

I then expanded said matrices to tensors, where the first shape index would refer to the dataset. If I have 3 datasets (for simplicity purposes), the matrices would look like this:

A: (3,7,1) numpy array
B: (3,6,1) numpy array
C: (3,7,6) numpy array

仅使用np.tensordot或其他numpy操纵,如何生成C?

我认为答案看起来像这样:

I assume the answer would look something like this:

C = np.tensordot(A.[some manipulation], B.[some manipulation], axes = (...))

(这是更复杂的应用程序的一部分,我构造事物的方式不再灵活.如果找不到解决方案,我将仅遍历数据集并对每个数据集执行乘法运算)

(This is a part of a much more complex application, and the way I'm structuring things is not flexible anymore. If I find no solution I will only loop through the datasets and perform the multiplication for each dataset)

推荐答案

我们可以使用

We can use np.einsum -

c = np.einsum('ijk,ilm->ijl',a,b)

由于最后一个轴是单轴,所以最好使用切片数组-

Since the last axes are singleton, you might be better off with sliced arrays -

c = np.einsum('ij,il->ijl',a[...,0],b[...,0])

使用 np.matmul/@-operator -

c = a@b.swapaxes(1,2)

这篇关于Python:用于多维数组的numpy.dot/numpy.tensordot的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆