Numba:不支持单元格变量 [英] Numba : cell vars are not supported

查看:40
本文介绍了Numba:不支持单元格变量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想用 numba 来加速这个功能:

I'd like to use numba to speed up this function:

from numba import jit
@jit
def rownowaga_numba(u, v):
    wymiar_x = len(u)
    wymiar_y = len(u[1])
    f = [[[0 for j in range(wymiar_y)] for i in range(wymiar_x)] for k in range(9)]
    cx = [0., 1., 0., -1., 0., 1., -1., -1., 1.]
    cy = [0., 0., 1., 0., -1., 1., 1., -1., -1.]
    w = [4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36] 
    for i in range( wymiar_x):
        for j in range (wymiar_y):
            for k in range(9):
                up = u[i][j]
                vp = v[i][j]
                udot = (up**2 + vp**2)
                cu = up*cx[k] + vp*cy[k]
                f[k][i][j] =  w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
     return f

我用这样的数据测试它的地方:

Where i test it with such data:

import timeit
import math as m

u = [[m.sin(i) + m.cos(j) for j in range(40)] for i in range(1000)]
y = [[m.sin(i) + m.cos(j) for j in range(40)] for i in range(1000)]

t0 = timeit.default_timer()

for i in range (10):
    f = rownowaga_pypy(u,y)

dt = timeit.default_timer() - t0
print('loop time:', dt)

我收到此错误:

    Traceback (most recent call last):
  File "C:\Users\Ricevind\Desktop\PyPy\Skrypty\Rownowaga.py", line 29, in <module>
    f = rownowaga_pypy(u,y)
  File "C:\pyzo2014a\lib\site-packages\numba\dispatcher.py", line 171, in _compile_for_args
    return self.compile(sig)
  File "C:\pyzo2014a\lib\site-packages\numba\dispatcher.py", line 348, in compile
    flags=flags, locals=self.locals)
  File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 637, in compile_extra
    return pipeline.compile_extra(func)
  File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 356, in compile_extra
    raise e
  File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 351, in compile_extra
    bc = self.extract_bytecode(func)
  File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 343, in extract_bytecode
    bc = bytecode.ByteCode(func=self.func)
  File "C:\pyzo2014a\lib\site-packages\numba\bytecode.py", line 343, in __init__
    raise NotImplementedError("cell vars are not supported")
NotImplementedError: cell vars are not supported

我最感兴趣的是不支持单元格变量"的含义,因为 Google 没有返回任何有意义的结果.

I'm mostly interested in the meaning of "cell vars are not supported" as Google returns no meaning results.

推荐答案

Numba 目前在嵌套列表列表中的效果不是特别好(至少从 v0.21 开始).我相信这就是单元格变量"错误所指的内容,但我不是 100% 确定.下面,我将所有内容都转换为 numpy 数组,以使代码能够被 numba 优化:

Numba does not work particularly well currently on nested list of lists (as of v0.21 at least). I believe that this is what the 'cell vars' error is referring to, but I'm not 100% sure. Below, I convert everything to numpy arrays to enable the code to be optimized by numba:

import numpy as np
import numba as nb
import math

def rownowaga(u, v):
    wymiar_x = len(u)
    wymiar_y = len(u[1])
    f = [[[0 for j in range(wymiar_y)] for i in range(wymiar_x)] for k in range(9)]
    cx = [0., 1., 0., -1., 0., 1., -1., -1., 1.]
    cy = [0., 0., 1., 0., -1., 1., 1., -1., -1.]
    w = [4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36] 
    for i in range( wymiar_x):
        for j in range (wymiar_y):
            for k in range(9):
                up = u[i][j]
                vp = v[i][j]
                udot = (up**2 + vp**2)
                cu = up*cx[k] + vp*cy[k]
                f[k][i][j] =  w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f

# Pull these out so that numba treats them as constant arrays
cx = np.array([0., 1., 0., -1., 0., 1., -1., -1., 1.])
cy = np.array([0., 0., 1., 0., -1., 1., 1., -1., -1.])
w = np.array([4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36]) 

@nb.jit(nopython=True)
def rownowaga_numba(u, v):
    wymiar_x = u.shape[0]
    wymiar_y = u[1].shape[0]
    f = np.zeros((9, wymiar_x, wymiar_y))

    for i in xrange( wymiar_x):
        for j in xrange (wymiar_y):
            for k in xrange(9):
                up = u[i,j]
                vp = v[i,j]
                udot = (up*up + vp*vp)
                cu = up*cx[k] + vp*cy[k]
                f[k,i,j] =  w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f

现在让我们设置一些测试数组:

Now let's set up some test arrays:

u = [[math.sin(i) + math.cos(j) for j in range(40)] for i in range(1000)]
y = [[math.sin(i) + math.cos(j) for j in range(40)] for i in range(1000)]

u_np = np.array(u)
y_np = np.array(y)

首先让我们验证我的 numba 代码给出与 OP 代码相同的答案:

First let's verify that my numba code is giving the same answer as the OP's code:

f1 = rownowaga(u, y)
f2 = rownowaga_numba(u_np, y_np)

来自 ipython 笔记本:

From an ipython notebook:

In [13]: np.allclose(f2, np.array(f1))
Out[13]:
True

现在让我们在笔记本电脑上计时:

And now let's time things on my laptop:

In [15] %timeit f1 = rownowaga(u, y)
1 loops, best of 3: 288 ms per loop


In [16] %timeit f2 = rownowaga_numba(u_np, y_np)
1000 loops, best of 3: 973 µs per loop

因此,我们以最少的代码更改获得了 300 倍的不错的加速.请注意,我使用的是 0.22 之前的 Numba 每晚版本:

So we get a nice 300x speed-up with minimal code changes. Just to note, I'm using a nightly build of Numba from a little before 0.22:

In [16]: nb.__version__
Out[16]:
'0.21.0+137.gac9929d'

这篇关于Numba:不支持单元格变量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆