使用if语句优化作用于numpy数组的函数 [英] Optimize a function that acts on a numpy array with an if statement
问题描述
假设我有一个类似的代码:
Suppose I have a code like:
import numpy as np
def value_error(x):
if x > 10:
return 0.
else:
return np.sin(x)
如果调用numpy数组,这可能会给我一个ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
.
This could give me a ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
if called upon an numpy array.
现在我可以改为:
def alright(x):
return np.sin(x) * (x <= 10.)
print alright(np.ones(100) * 100)
print value_error(np.ones(100) * 10)
我的功能(在本例中为np.sin
)可能是一个昂贵的功能.但是,它会调用x
的每个元素,即使我知道答案的因为x > 10
的元素也是如此,而无需进行昂贵的调用.我如何才能两全其美?
My function (in this case np.sin
) could be an expensive one. It is, however, called for every element of x
, even ones where I know the answer because x > 10
, without an expensive call. How can I get the best of both worlds?
推荐答案
这是一个基于掩码的掩码,只能在有效掩码上与np.sin
一起使用-
Here's a mask based one that operates with np.sin
only on the valid ones -
out = np.zeros(x.shape)
mask = x <= 10
out[mask] = np.sin(x[mask])
利用 numexpr
模块更快transcendental
操作-
import numexpr as ne
out = np.zeros(x.shape)
mask = x <= 10
x_masked = x[mask]
out[mask] = ne.evaluate('sin(x_masked)')
这篇关于使用if语句优化作用于numpy数组的函数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!