将Numpy中的矩阵列表相乘 [英] Multiply together list of matrices in Numpy

查看:555
本文介绍了将Numpy中的矩阵列表相乘的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在寻找一种有效的方法来将Numpy中的矩阵列表相乘.我有一个像这样的矩阵:

I'm looking for an efficient way to multiply a list of matrices in Numpy. I have a matrix like this:

import numpy as np
a = np.random.randn(1000, 4, 4)

我想沿长轴矩阵相乘,所以结果是4x4矩阵.显然,我可以这样做:

I want to matrix-multiply along the long axis, so the result is a 4x4 matrix. So clearly I can do:

res = np.identity(4)
for ai in a:
    res = np.matmul(res, ai)

但这太慢了.有没有更快的方法(也许使用einsum或我尚不完全了解的其他功能)?

But this is super-slow. Is there a faster way (perhaps using einsum or some other function that I don't fully understand yet)?

推荐答案

对于幂数为2的堆栈,需要log_2(n) for循环交互的解决方案可能是

A solution that requires log_2(n) for loop interations for stacks with size of powers of 2 could be

while len(a) > 1:
    a = np.matmul(a[::2, ...], a[1::2, ...])

基本上是将两个相邻的矩阵迭代相乘,直到只剩下一个矩阵,每次迭代完成剩余乘法的一半.

which essentially iteratively multiplies two neighbouring matrices together until there is only one matrix left, doing half of the remaining multiplications per iteration.

res = A * B * C * D * ...         # 1024 remaining multiplications

成为

res = (A * B) * (C * D) * ...     # 512 remaining multiplications

成为

res = ((A * B) * (C * D)) * ...   # 256 remaining multiplications

对于非2的幂,您可以对第一个2^n矩阵执行此操作,对其余矩阵使用算法.

For non-powers of 2 you can do this for the first 2^n matrices and use your algorithm for the remaining matrices.

这篇关于将Numpy中的矩阵列表相乘的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆