Optimize a recursive function in Julia

Question

I wrote Julia code that computes integrals over Gaussian functions, and it has a kernel function that is called over and over again. According to Julia's built-in Profile module, this function is where most of the time goes during the actual computation, so I would like to see whether there is any way to improve it.

It is a recursive function, and I implemented it in a fairly straightforward way. As I am not very used to recursive functions, maybe somebody has ideas or suggestions on how to improve it, either from a purely algorithmic point of view or by exploiting special optimizations of the JIT compiler.

Here it is:

using LinearAlgebra  # provides norm on Julia 0.7 and later

"""Returns the integral of an Hermite Gaussian divided by the Coulomb operator."""
function Rtuv(t::Int, u::Int, v::Int, n::Int, p::Real, RPC::Vector{T}) where {T<:Real}
    if t == u == v == 0
        return (-2.0*p)^n * boys(n,p*norm(RPC)^2)
    elseif u == v == 0
        if t > 1
            return  (t-1)*Rtuv(t-2, u, v, n+1, p, RPC) +
                   RPC[1]*Rtuv(t-1, u, v, n+1, p, RPC)
        else
            return RPC[1]*Rtuv(t-1, u, v, n+1, p, RPC)
        end
    elseif v == 0
        if u > 1
            return  (u-1)*Rtuv(t, u-2, v, n+1, p, RPC) +
                   RPC[2]*Rtuv(t, u-1, v, n+1, p, RPC)
        else
            return RPC[2]*Rtuv(t, u-1, v, n+1, p, RPC)
        end
    else
        if v > 1
            return  (v-1)*Rtuv(t, u, v-2, n+1, p, RPC) +
                   RPC[3]*Rtuv(t, u, v-1, n+1, p, RPC)
        else
            return RPC[3]*Rtuv(t, u, v-1, n+1, p, RPC)
        end
    end
end

Don't pay too much attention to the function boys; according to the profiler it is not that expensive.
To give an idea of the range of numbers: in the first call, t+u+v usually ranges from 0 to 3, while n always starts at 0.

Cheers

The generated version is slower for small values of t, u, v; I believe the reason is that the expressions are not optimized by the compiler. I was benchmarking this case badly, without interpolating the arguments passed. Benchmarked properly, the approach explained in the accepted answer is always faster, so hurray!

More generally, does the compiler identify trivial cases such as multiplication by zeros and ones and optimize those away?

Answer to myself: from a quick check of simple code with @code_llvm, this seems not to be the case.
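One way to run such a check is sketched below. The helper names mul_by_one and mul_by_zero are made up for this illustration, and the IR that gets printed depends on the Julia and LLVM versions, so none is shown. One thing holds regardless of version: under IEEE 754 rules a floating-point multiplication by zero cannot simply be replaced by zero, because 0.0 * Inf and 0.0 * NaN are NaN and 0.0 * -1.0 is -0.0.

```julia
using InteractiveUtils  # provides @code_llvm outside the REPL

# Hypothetical one-liners, only for inspecting the generated code.
mul_by_one(x::Float64) = 1.0 * x
mul_by_zero(x::Float64) = 0.0 * x

# Print the LLVM IR and look for a remaining `fmul` instruction.
@code_llvm mul_by_one(2.0)
@code_llvm mul_by_zero(2.0)

# Why the zero case cannot be folded to a constant in general:
println(mul_by_zero(Inf))   # NaN
println(mul_by_zero(-1.0))  # -0.0
```

The same inspection can be applied to any small function whose arithmetic is suspected of carrying redundant factors.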

Accepted answer

Maybe this works in your case: you can "memoize" whole compiled methods using generated functions and get rid of all recursion after the first call.

Since t, u, and v will stay small, you could generate the fully expanded code for the recursions. For simplicity, assume a bogus implementation of boys:

boys(n::Int, x::Real) = n + x

Then

function Rtuv_expr(t::Int, u::Int, v::Int, n, p, RPC)
    ninc = :($n + 1)

    if t == u == v == 0
        :((-2.0 * $p)^$n * boys($n, $p * norm($RPC)^2))
    elseif u == v == 0
        if t > 1
            :($(t-1) * $(Rtuv_expr(t-2, u, v, ninc, p, RPC)) +
              $RPC[1] * $(Rtuv_expr(t-1, u, v, ninc, p, RPC)))
        else
            :($RPC[1] * $(Rtuv_expr(t-1, u, v, ninc, p, RPC)))
        end
    elseif v == 0
        if u > 1
            :($(u-1) * $(Rtuv_expr(t, u-2, v, ninc, p, RPC)) +
              $RPC[2] * $(Rtuv_expr(t, u-1, v, ninc, p, RPC)))
        else
            :($RPC[2] * $(Rtuv_expr(t, u-1, v, ninc, p, RPC)))
        end
    else
        if v > 1 
            :($(v-1) * $(Rtuv_expr(t, u, v-2, ninc, p, RPC)) + 
              $RPC[3] * $(Rtuv_expr(t, u, v-1, ninc, p, RPC)))
        else
            :($RPC[3] * $(Rtuv_expr(t, u, v-1, ninc, p, RPC)))
        end
    end
end

will generate fully expanded expressions like this:

julia> Rtuv_expr(1, 2, 1, 0, 0.1, rand(3))
:(([0.868194, 0.928591, 0.295344])[3] * (1 * (([0.868194, 0.928591, 0.295344])[1] * ((-2.0 * 0.1) ^ (((0 + 1) + 1) + 1) * boys(((0 + 1) + 1) + 1, 0.1 * norm([0.868194, 0.928591, 0.295344]) ^ 2))) + ([0.868194, 0.928591, 0.295344])[2] * (([0.868194, 0.928591, 0.295344])[2] * (([0.868194, 0.928591, 0.295344])[1] * ((-2.0 * 0.1) ^ ((((0 + 1) + 1) + 1) + 1) * boys((((0 + 1) + 1) + 1) + 1, 0.1 * norm([0.868194, 0.928591, 0.295344]) ^ 2))))))

We can stuff that into a generated function Rtuv taking Val types. For each different combination of T, U, and V, this function will use Rtuv_expr to compile the respective expression and from then on use this method -- no recursion anymore:

@generated function Rtuv(::Type{Val{T}}, ::Type{Val{U}}, ::Type{Val{V}},
                         n::Int, p::Real, RPC::Vector{X}) where {T, U, V, X<:Real}
    Rtuv_expr(T, U, V, :n, :p, :RPC)
end

You have to call it with t, u, v wrapped in Val, though:

julia> Rtuv(Val{1}, Val{2}, Val{1}, 0, 0.1, rand(3))
-0.0007782250832001092

If you test a small loop like this,

for t = 0:3, u = 0:3, v = 0:3
    println(Rtuv(Val{t}, Val{u}, Val{v}, 0, 0.1, [1.0, 2.0, 3.0]))
end

it will need some time for the first run, but then run pretty fast, since the methods used are already compiled.
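A quick sanity check that the unrolled methods agree with the plain recursion could look like the sketch below. It assumes Julia 1.x syntax (where clauses, using LinearAlgebra for norm) and the bogus boys from above. The names Rtuv_rec and Rtuv_gen are made up here so that both variants can live side by side, and the recurrences are written in a compact form rather than copied verbatim.

```julia
using LinearAlgebra  # norm

# Bogus stand-in used in the answer, not the real Boys function.
boys(n::Int, x::Real) = n + x

# Plain recursion, same recurrences as in the question (hypothetical name).
function Rtuv_rec(t::Int, u::Int, v::Int, n::Int, p::Real, RPC::Vector{<:Real})
    if t == u == v == 0
        return (-2.0 * p)^n * boys(n, p * norm(RPC)^2)
    elseif u == v == 0
        r = RPC[1] * Rtuv_rec(t-1, u, v, n+1, p, RPC)
        return t > 1 ? (t-1) * Rtuv_rec(t-2, u, v, n+1, p, RPC) + r : r
    elseif v == 0
        r = RPC[2] * Rtuv_rec(t, u-1, v, n+1, p, RPC)
        return u > 1 ? (u-1) * Rtuv_rec(t, u-2, v, n+1, p, RPC) + r : r
    else
        r = RPC[3] * Rtuv_rec(t, u, v-1, n+1, p, RPC)
        return v > 1 ? (v-1) * Rtuv_rec(t, u, v-2, n+1, p, RPC) + r : r
    end
end

# Expression builder with the same structure as Rtuv_expr in the answer.
function Rtuv_expr(t::Int, u::Int, v::Int, n, p, RPC)
    ninc = :($n + 1)
    if t == u == v == 0
        :((-2.0 * $p)^$n * boys($n, $p * norm($RPC)^2))
    elseif u == v == 0
        r = :($RPC[1] * $(Rtuv_expr(t-1, u, v, ninc, p, RPC)))
        t > 1 ? :($(t-1) * $(Rtuv_expr(t-2, u, v, ninc, p, RPC)) + $r) : r
    elseif v == 0
        r = :($RPC[2] * $(Rtuv_expr(t, u-1, v, ninc, p, RPC)))
        u > 1 ? :($(u-1) * $(Rtuv_expr(t, u-2, v, ninc, p, RPC)) + $r) : r
    else
        r = :($RPC[3] * $(Rtuv_expr(t, u, v-1, ninc, p, RPC)))
        v > 1 ? :($(v-1) * $(Rtuv_expr(t, u, v-2, ninc, p, RPC)) + $r) : r
    end
end

# One compiled, recursion-free method per (T, U, V) combination.
@generated function Rtuv_gen(::Type{Val{T}}, ::Type{Val{U}}, ::Type{Val{V}},
                             n::Int, p::Real, RPC::Vector{X}) where {T, U, V, X<:Real}
    Rtuv_expr(T, U, V, :n, :p, :RPC)
end

# Compare both variants over the index range mentioned in the question.
RPC = [1.0, 2.0, 3.0]
maxdiff = maximum(abs(Rtuv_gen(Val{t}, Val{u}, Val{v}, 0, 0.1, RPC) -
                      Rtuv_rec(t, u, v, 0, 0.1, RPC))
                  for t = 0:3, u = 0:3, v = 0:3)
println("largest deviation: ", maxdiff)
```

Passing Val{t} with a runtime t, as in the loop above, goes through dynamic dispatch. The speedup applies fully only where t, u, and v are known when the call site is compiled.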
