如何使用 scipy.optimize.curve_fit 来使用变量列表 [英] How to use scipy.optimize.curve_fit to use lists of variable

查看:61
本文介绍了如何使用 scipy.optimize.curve_fit 来使用变量列表的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有时间变化的数据跟踪,我想将函数拟合到其中.函数的输入是列表,我希望 curve_fit 优化列表中的所有值以拟合曲线.我已经到了-

I have time varying data trace which I want to fit a function to. The inputs to the functions are lists and I want the curve_fit to optimize all values in the list to fit the curve. I have gotten so far-

from scipy.optimize import curve_fit
from matplotlib.pylab import plt
from numpy import exp

def ffunc2(x, a, b):
    counter = 0
    return_value = 0
    while counter < len(a):
        return_value += a[counter] * exp(b[counter] * x)
        counter += 1
    return return_value

# INITIAL DATA
x = [1, 2, 3, 5]
y = [1, 8, 81, 125]


number_variable = 2

# INTIAL GUESS
p0 = []
counter = 0
while counter < number_variable:
    p0.append(0.0)
    counter += 1

p, _ = curve_fit(ffunc2, x, y, p0=[0.0, 0.0])

我想创建一个循环,它通过最小化错误来为我提供最大数量的变量的最佳拟合.

I want to create a loop which iterates such that it gives me the best fit with maximum number of variables by minimizing the error.

我也发现了这个讨论 -使用 scipy curve_fit 获取可变数量的参数

I have found this discussion as well - Using scipy curve_fit for a variable number of parameters

from numpy import exp
from scipy.optimize import curve_fit

def wrapper_fit_func(x, N, *args):
    a, b, c = list(args[0][:N]), list(args[0][N:2*N]), list(args[0][2*N:3*N])
    return fit_func(x, a, b)

def fit_func(x, a, b):
    counter = 0
    return_value = 0
    while counter < len(a):
        return_value += a[counter] * exp(b[counter] * x)
        counter += 1
    return return_value


x = [1, 2, 3, 5]
y = [1, 8, 81, 125]
params_0 = [0,1.0,2.0,3.0,4.0,5.0]

popt, pcov = curve_fit(lambda x, *params_0: wrapper_fit_func(x, 3, params_0), x, y, p0=params_0)

但得到一个错误 -´´´文件C:\python\lib\site-packages\scipy\optimize\minpack.py",第 387 行,最小平方raise TypeError('不正确的输入:N=%s 不能超过 M=%s' % (n, m))类型错误:不正确的输入:N=6 不得超过 M=4'''

But get an error -´´´ File "C:\python\lib\site-packages\scipy\optimize\minpack.py", line 387, in leastsq raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m)) TypeError: Improper input: N=6 must not exceed M=4 ´´´

推荐答案

我能够直接使用 sci_py.optimize.least_squares 解决这个问题,因为它接受元组作为输入而不是直接变量.但我必须定义误差函数.我认为它有助于解决我现在的问题.

I was able to solve this be using sci_py.optimize.least_squares directly as it takes in tuple as an input and not variables directly. But I do have to define the error function. I would assume it help solve my problem as of now.

from scipy.optimize import least_squares
from matplotlib.pylab import plt
from numpy import exp
import numpy as np


# define function to fit
def ffunc2(a, x):
    counter = 0
    return_value = 0
    while counter < len(a):
        return_value += a[counter] * exp(x * a[counter + 1])
        counter += 2

    return return_value


def error_func(tpl, x, y):
    return ffunc2(tpl,x) - y


# INPUT DATA
x = np.array([1, 2, 3, 5])
y = np.array([0.22103418, 0.24428055, 0.26997176, 0.32974425,])

# INITIAL GUESS
p0 = (1, 1)*10
output = least_squares(error_func, x0=p0, jac='2-point', bounds=(0, np.inf), method='trf', ftol=1e-08,
                       xtol=1e-08, gtol=1e-08, x_scale=1.0, loss='linear', f_scale=1.0, diff_step=None,
                       tr_solver=None,
                       tr_options={}, jac_sparsity=None, max_nfev=None, verbose=0, args=(x, y))

tpl_final = output.x
print (tpl_final)

final_curve = ffunc2(tpl_final,x)

plt.plot(x, y, 'r-', x, final_curve, 'g-')
plt.show()

这篇关于如何使用 scipy.optimize.curve_fit 来使用变量列表的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆