numpy multi_dot比numpy.dot慢吗? [英] How is numpy multi_dot slower than numpy.dot?

查看:138
本文介绍了numpy multi_dot比numpy.dot慢吗?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试优化一些执行大量顺序矩阵运算的代码.

I'm trying to optimize some code that performs lots of sequential matrix operations.

我想出了numpy.linalg.multi_dot(

I figured numpy.linalg.multi_dot (docs here) would perform all the operations in C or BLAS and thus it would be way faster than going something like arr1.dot(arr2).dot(arr3) and so on.

我真的很惊讶在笔记本上运行以下代码:

I was really surprised running this code on a notebook:

v1 = np.random.rand(2,2)

v2 = np.random.rand(2,2)



%%timeit 
    ​    
v1.dot(v2.dot(v1.dot(v2)))

The slowest run took 9.01 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.14 µs per loop



%%timeit        ​

np.linalg.multi_dot([v1,v2,v1,v2])

The slowest run took 4.67 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 32.9 µs per loop

要发现使用multi_dot进行相同操作的速度要慢10倍左右.

To find out that the same operation is about 10x slower using multi_dot.

我的问题是:

  • 我想念什么吗?有什么意义吗?
  • 还有另一种优化顺序矩阵运算的方法吗?
  • 使用cython是否应该获得相同的行为?

推荐答案

这是因为您的测试矩阵太小且太规则;确定最快评估顺序的开销可能会超过潜在的性能提升.

It's because your test matrices are too small and too regular; the overhead in figuring out the fastest evaluation order may outweights the potential performance gain.

使用文档中的示例:

import numpy as snp
from numpy.linalg import multi_dot

# Prepare some data
A = np.random.rand(10000, 100)
B = np.random.rand(100, 1000)
C = np.random.rand(1000, 5)
D = np.random.rand(5, 333)

%timeit -n 10 multi_dot([A, B, C, D])
%timeit -n 10 np.dot(np.dot(np.dot(A, B), C), D)
%timeit -n 10 A.dot(B).dot(C).dot(D)

结果:

10 loops, best of 3: 12 ms per loop
10 loops, best of 3: 62.7 ms per loop
10 loops, best of 3: 59 ms per loop

multi_dot通过评估标量乘法最少的最快乘法顺序来提高性能.

multi_dot improves performance by evaluating the fastest multiplication order in which there are least scalar multiplications.

在上述情况下,默认的正则乘法阶数((AB)C)D被评估为A((BC)D)-从而将1000x100 @ 100x1000乘法减小为1000x100 @ 100x333,从而至少减少了2/3标量乘法.

In the above case, the default regular multiplication order ((AB)C)D is evaluated as A((BC)D)--so that a 1000x100 @ 100x1000 multiplication is reduced to 1000x100 @ 100x333, cutting down at least 2/3 scalar multiplications.

您可以通过测试进行验证

You can verify this by testing

%timeit -n 10 np.dot(A, np.dot(np.dot(B, C), D))
10 loops, best of 3: 19.2 ms per loop

这篇关于numpy multi_dot比numpy.dot慢吗?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆