Efficiently multiply elements of each row together

Question

Given an ndarray of shape (n, 3) with n around 1000, how can I quickly multiply together all the elements of each row? The (inelegant) second solution below runs in about 0.3 milliseconds; can it be improved?

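import time
import numpy as np

# dummy data
n = 999
# integer division: reshape needs integer dimensions (gives a (333, 3) array)
a = np.random.uniform(low=0, high=10, size=n).reshape(n // 3, 3)

# two solutions
def prod1(array):
    # one np.prod call per row
    return [np.prod(row) for row in array]

def prod2(array):
    # explicit product of the three entries of each row
    return [row[0]*row[1]*row[2] for row in array]

# benchmark
start = time.time()
prod1(a)
print(time.time() - start)
# 0.0015

start = time.time()
prod2(a)
print(time.time() - start)
# 0.0003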

Recommended answer

Improving performance further

First, a general rule of thumb: you are working with numerical arrays, so use arrays and not lists. Lists may look somewhat like a general-purpose array, but they are completely different under the hood and are not suitable for most numerical calculations.

If you write simple code using NumPy arrays, you can gain performance just by JIT-compiling it with Numba, as shown below. If you use lists, you more or less have to rewrite your code.

import numpy as np
import numba as nb

@nb.njit(fastmath=True)
def prod(array):
  assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)
  res=np.empty(array.shape[0],dtype=array.dtype)
  for i in range(array.shape[0]):
    res[i]=array[i,0]*array[i,1]*array[i,2]

  return res
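
As a minimal usage sketch, the function above can be warmed up once and then checked against NumPy's `np.prod(a, axis=1)`. The test data, the seed, and the `jit` fallback (which simply runs plain Python when Numba is not installed) are assumptions added for illustration, not part of the original answer.

```python
import numpy as np

try:
    import numba as nb
    jit = nb.njit(fastmath=True)
except ImportError:
    # Assumption: fall back to uncompiled Python if Numba is unavailable.
    def jit(func):
        return func

@jit
def prod(array):
    assert array.shape[1] == 3
    res = np.empty(array.shape[0], dtype=array.dtype)
    for i in range(array.shape[0]):
        res[i] = array[i, 0] * array[i, 1] * array[i, 2]
    return res

# Hypothetical test data: 333 rows of 3 values in [0, 10).
rng = np.random.default_rng(0)
a = rng.uniform(low=0, high=10, size=(333, 3))

prod(a)           # first call triggers compilation; exclude it from timings
result = prod(a)  # later calls reuse the compiled code

# fastmath may reorder floating-point operations, so compare with a tolerance.
reference = np.prod(a, axis=1)
matches = np.allclose(result, reference)
print(matches)
```

The comparison uses `np.allclose` rather than exact equality because `fastmath=True` allows the compiler to relax strict IEEE ordering.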

Using np.prod(a, axis=1) isn't a bad idea, but its performance isn't great here. For an array of only 1000x3, the function call overhead is quite significant. That overhead can be avoided entirely by calling the jitted prod function from within another jitted function.
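
To illustrate calling one jitted function from another, here is a small sketch. The outer function `sum_of_row_products` is a hypothetical consumer invented for this example, and the `jit` fallback to plain Python is likewise an assumption so the snippet still runs without Numba.

```python
import numpy as np

try:
    import numba as nb
    jit = nb.njit(fastmath=True)
except ImportError:
    # Assumption: fall back to uncompiled Python if Numba is unavailable.
    def jit(func):
        return func

@jit
def prod(array):
    res = np.empty(array.shape[0], dtype=array.dtype)
    for i in range(array.shape[0]):
        res[i] = array[i, 0] * array[i, 1] * array[i, 2]
    return res

@jit
def sum_of_row_products(array):
    # Hypothetical outer function: the call to prod happens inside compiled
    # code, so it does not go through the Python-level call path.
    products = prod(array)
    total = 0.0
    for i in range(products.shape[0]):
        total += products[i]
    return total

rng = np.random.default_rng(0)
a = rng.uniform(low=0, high=10, size=(1000, 3))

total = sum_of_row_products(a)
expected = np.prod(a, axis=1).sum()
print(np.isclose(total, expected))
```

Only the single call to `sum_of_row_products` crosses the boundary between Python and compiled code; how much this saves in practice depends on the surrounding workload.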

Benchmarks

# The first call to the jitted function incurs about 200 ms of compilation overhead.
# If you use @nb.njit(fastmath=True, cache=True), the compilation result is cached and reused on subsequent calls.
n=999
prod1   = 795  µs
prod2   = 187  µs
np.prod = 7.42 µs
prod    = 0.85 µs

n=9990
prod1   = 7863 µs
prod2   = 1810 µs
np.prod = 50.5 µs
prod    = 2.96 µs
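
Timings this small are noisy when measured with a single `time.time()` call, so a repeated measurement is more reliable. The sketch below uses the standard-library `timeit` module. The helper `best_time_us`, the array shape (999, 3), and the column-wise variant `prod_columns` are assumptions added for illustration and do not come from the original answer. No numbers are shown because results depend on the machine.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
a = rng.uniform(low=0, high=10, size=(999, 3))  # assumed shape for n=999

def prod2(array):
    # list-based solution from the question
    return [row[0] * row[1] * row[2] for row in array]

def prod_numpy(array):
    # vectorized product along each row
    return np.prod(array, axis=1)

def prod_columns(array):
    # alternative (not in the original): multiply the three columns directly
    return array[:, 0] * array[:, 1] * array[:, 2]

def best_time_us(func, array, number=100, repeat=5):
    """Return the best average time per call, in microseconds."""
    runs = timeit.repeat(lambda: func(array), number=number, repeat=repeat)
    return min(runs) / number * 1e6

timings = {f.__name__: best_time_us(f, a)
           for f in (prod2, prod_numpy, prod_columns)}
for name, t in timings.items():
    print(f"{name:12s} {t:10.2f} µs")
```

Taking the minimum over several repeats reduces the influence of background activity on the measurement.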
