如何编写包装器以修复函数中的任意参数 [英] How to write a wrapper to fix arbitrary parameters in a function

查看:86
本文介绍了如何编写包装器以修复函数中的任意参数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想编写一个曲线拟合脚本,使我可以修复以下形式的函数的参数:

I would like to write a curve-fitting script that allows me to fix parameters of a function of the form:

def func(x, *p):
    assert len(p) % 2 == 0
    fval = 0
    for j in xrange(0, len(p), 2):
        fval += p[j]*np.exp(-p[j+1]*t)
    return fval

例如,假设我要p = [p1,p2,p3,p4],并且我希望p2和p3为常数A和B(从4参数拟合到2参数拟合).我知道functools.partial不允许我这样做,这就是为什么我要编写自己的包装器.但是我这样做有点麻烦.这是我到目前为止的内容:

For example, let's say I want p = [p1, p2, p3, p4], and I want p2 and p3 to be constant A and B (going from a 4-parameter fit to a 2-parameter fit). I understand that functools.partial doesn't let me do this which is why I want to write my own wrapper. But I am having a bit of trouble doing so. This is what I have so far:

def fix_params(f, t, pars, fix_pars):
    # fix_pars = ((ind1, A), (ind2, B))
    new_pars = [None]*(len(pars) + len(fix_pars))
    for ind, fix in fix_pars:
        new_pars[ind] = fix
    for par in pars:
        for j, npar in enumerate(new_pars):
            if npar == None:
                new_pars[j] = par
                break
    assert None not in new_pars
    return f(t, *new_pars)

我认为与此有关的问题是,scipy.optimize.curve_fit不能与通过此类包装程序传递的函数一起很好地工作.我应该如何解决这个问题?

The problem with this I think is that, scipy.optimize.curve_fit won't work well with a function passed through this kind of wrapper. How should I get around this?

推荐答案

所以我认为我有一些可行的方法.也许有办法对此进行改进.

So I think I have something workable. Maybe there is a way to improve on this.

这是我的代码(没有所有异常处理):

Here is my code (without all the exception handling):

def func(x, *p):
    fval = 0
    for j in xrange(0, len(p), 2):
        fval += p[j]*np.exp(-p[j+1]*x)
    return fval

def fix_params(f, fix_pars):
    # fix_pars = ((1, A), (2, B))
    def new_func(x, *pars):
        new_pars = [None]*(len(pars) + len(fix_pars))
        for j, fp in fix_pars:
            new_pars[j] = fp
        for par in pars:
            for j, npar in enumerate(new_pars):
                if npar is None:
                    new_pars[j] = par
                    break
        return f(x, *new_pars)
    return new_func

p1 = [1, 0.5, 0.1, 1.2]
pfix = ((1, 0.5), (2, 0.1))
p2 = [1, 1.2]

new_func = fix_params(func, pfix)

x = np.arange(10)
dat1 = func(x, *p1)
dat2 = new_func(x, *p2)

if (dat1==dat2).all()
    print "ALL GOOD"

这篇关于如何编写包装器以修复函数中的任意参数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆