NumPy array computation slower than equivalent Java code


Question

I am trying to work with large 2D arrays in Python, but it is very slow. For example:

import time
import numpy

start = time.time()
result = numpy.empty([5000, 5000])

for i in range(5000):
    for j in range(5000):
        result[i, j] = (i * j) % 10

end = time.time()
print(end - start) # 8.8 s

The same program in Java is much faster:

long start = System.currentTimeMillis();
int[][] result = new int[5000][5000];

for (int i = 0; i < 5000; i++) {
    for (int j = 0; j < 5000; j++) {
        result[i][j] = (i * j) % 10;
    }
}

long end = System.currentTimeMillis();
System.out.println(end - start); // 121 ms

Is this because Python is an interpreted language? Is there any way to improve it? And if Python is this slow, why is it so popular for working with matrices, artificial intelligence, etc.?

Recommended answer

Read to the end to see how NumPy can outperform your Java code by 5x.

NumPy's strength lies in vectorized computations. Your Python code relies on interpreted loops, and interpreted loops tend to be slow: every iteration goes through the Python interpreter instead of running inside NumPy's compiled code.

I rewrote your Python code as a vectorized computation, and that immediately sped it up by a factor of ~16 (8.8 s down to 544 ms):

In [41]: v = np.arange(5000)

In [42]: %timeit np.outer(v, v) % 10
1 loop, best of 3: 544 ms per loop
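As a sanity check, the vectorized expression can be compared against the original double loop at a small size. This is only a minimal sketch: the function names loop_version and vectorized_version and the size 50 are illustrative assumptions, not part of the original answer, and np is assumed to be NumPy imported as usual.

```python
import numpy as np


def loop_version(n):
    """Fill an n x n array element by element, as in the question."""
    result = np.empty([n, n])
    for i in range(n):
        for j in range(n):
            result[i, j] = (i * j) % 10
    return result


def vectorized_version(n):
    """Compute the same table with a single outer product."""
    v = np.arange(n)
    return np.outer(v, v) % 10


n = 50  # hypothetical small size so the interpreted loop finishes quickly
same = np.array_equal(loop_version(n), vectorized_version(n))
print(same)
```

The loop version yields floats (numpy.empty defaults to float64) while the vectorized one yields integers, but the values agree, which is what the comparison checks.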

Computing % 10 in place instead of creating a new array speeds things up by another 20%:

In [37]: def f(n):
    ...:     v = np.arange(n)
    ...:     a = np.outer(v, v)
    ...:     a %= 10
    ...:     return a
    ...:

In [39]: %timeit f(5000)
1 loop, best of 3: 437 ms per loop

Edit 1: Doing the computations in 32 bits instead of 64 (to match your Java code) brings the performance roughly level with Java (126 ms vs. 121 ms); h/t to @user2357112 for pointing this out:

In [50]: def f(n):
    ...:     v = np.arange(n, dtype=np.int32)
    ...:     a = np.outer(v, v)
    ...:     a %= 10
    ...:     return a
    ...:

In [51]: %timeit f(5000)
10 loops, best of 3: 126 ms per loop
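Switching to int32 is only safe if no product i * j overflows 32 bits. A quick check, sketched here with illustrative variable names that do not appear in the original answer, suggests there is ample headroom for n = 5000:

```python
import math

import numpy as np

n = 5000
largest_product = (n - 1) * (n - 1)   # 4999 * 4999, the biggest i * j
int32_max = np.iinfo(np.int32).max    # 2**31 - 1

fits = largest_product <= int32_max
print(largest_product, int32_max, fits)

# Largest n for which (n - 1) ** 2 still fits in a signed 32-bit integer.
n_limit = math.isqrt(int(int32_max)) + 1
print(n_limit)
```

Beyond that limit the int32 version would silently wrap around, so for much larger arrays the 64-bit computation (or reducing modulo 10 before multiplying) would be the safer choice.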

Edit 2: With a little more work, this code can be made about 5x faster than your Java implementation (here ne refers to the numexpr module, conventionally imported with import numexpr as ne):

In [69]: v = np.arange(5000, dtype=np.int32)

In [70]: vt = v[np.newaxis].T

In [71]: %timeit ne.evaluate('v * vt % 10')
10 loops, best of 3: 25.3 ms per loop
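The expression passed to numexpr relies on broadcasting: v has shape (n,) and vt has shape (n, 1), so their product is the full n x n table. The sketch below shows the same expression in plain NumPy (no numexpr needed) and compares it with the np.outer version; the size 200 and the names broadcast and reference are illustrative assumptions. It says nothing about speed, since the timing above comes from numexpr evaluating the expression itself.

```python
import numpy as np

n = 200  # hypothetical small size, enough to show the shapes
v = np.arange(n, dtype=np.int32)
vt = v[np.newaxis].T           # shape (n, 1): v as a column vector

# * and % have equal precedence and group left to right: (v * vt) % 10
broadcast = v * vt % 10        # (n,) with (n, 1) broadcasts to (n, n)
reference = np.outer(v, v) % 10

print(vt.shape, broadcast.shape, np.array_equal(broadcast, reference))
```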

Edit 3: Please make sure to also take a look at the answer given by @max9111.
