Optimization on piecewise linear regression


Question

I am trying to create a piecewise linear regression that minimizes the MSE (mean squared error), rather than fitting a single linear regression directly. The method should use dynamic programming to evaluate the different piecewise sizes and combinations of groups and find the best overall MSE. I think the algorithm's runtime is O(n²), and I wonder if there are ways to optimize it to O(n log n)?

import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn import linear_model
import pandas as pd
import matplotlib.pyplot as plt

x = [3.4, 1.8, 4.6, 2.3, 3.1, 5.5, 0.7, 3.0, 2.6, 4.3, 2.1, 1.1, 6.1, 4.8, 3.8]
y = [26.2, 17.8, 31.3, 23.1, 27.5, 36.5, 14.1, 22.3, 19.6, 31.3, 24.0, 17.3, 43.2, 36.4, 26.1]

# Pair the points and sort them by x so that contiguous slices form segments.
dataset = np.dstack((x, y))[0]
d_arg = np.argsort(dataset[:, 0])
dataset = dataset[d_arg]

def calc_error(dataset):
    """Fit an ordinary linear regression to one segment and return its MSE."""
    lr_model = linear_model.LinearRegression()
    x = pd.DataFrame(dataset[:, 0])
    y = pd.DataFrame(dataset[:, 1])
    lr_model.fit(x, y)
    predictions = lr_model.predict(x)
    mse = mean_squared_error(y, predictions)
    return mse

# n is the number of points, m is the number of groups,
# k is the minimum number of points in a group.
# all_combination(15, 5, 3) returns [[3, 3, 3, 3, 3]]
# all_combination(15, 5, 2) returns [[2, 4, 3, 3, 3], [3, 2, 4, 2, 4], [4, 2, 3, 3, 3], ...]
def all_combination(n, m, k):
    """Enumerate every way to split n points into m ordered groups of size >= k."""
    result = []
    if n < k * m:
        print('There are not enough elements to split.')
        return result
    combination_bottom = [k for q in range(m)]  # start every group at the minimum size
    # add greedy algorithm here?
    if n == k * m:
        result.append(combination_bottom.copy())
    else:
        # Grow the group sizes one point at a time until the sizes sum to n.
        combination_now = [combination_bottom.copy()]
        j = k * m + 1
        while j < n + 1:
            combination_last = combination_now.copy()
            combination_now = []
            for x in combination_last:
                for i in range(0, m):
                    combination_new = x.copy()
                    combination_new[i] = combination_new[i] + 1
                    combination_now.append(combination_new.copy())
            j += 1
        else:
            # The while-else runs once the loop finishes: distribute the final
            # point, deduplicating combinations reached by different paths.
            for x in combination_last:
                for i in range(0, m):
                    combination_new = x.copy()
                    combination_new[i] = combination_new[i] + 1
                    if combination_new not in result:
                        result.append(combination_new.copy())

    return result  # 2-d list

def calc_sum_error(dataset, cb):  # cb = combination, a list of group sizes
    """Sum the per-segment MSEs for one split of the sorted dataset."""
    mse_sum = 0
    low = 0
    for size in cb:
        high = low + size
        mse_sum += calc_error(dataset[low:high])
        low = high
    return mse_sum
            
    
    
# k is the minimum number of points in a group
def best_piecewise(dataset, k):
    """Try every number of groups and every split, and keep the best one."""
    length = len(dataset)
    max_split = length // k
    min_mse = calc_error(dataset)  # baseline: a single linear regression
    split_cb = []
    all_cb = []
    for i in range(2, max_split + 1):
        split_result = all_combination(length, i, k)
        all_cb += split_result
        for cb in split_result:
            tmp_mse = calc_sum_error(dataset, cb)
            if tmp_mse < min_mse:
                min_mse = tmp_mse
                split_cb = cb
    return min_mse, split_cb, all_cb

min_mse, split_cb, all_cb = best_piecewise(dataset, 2)

print('The best split of the data is '+str(split_cb))
print('The minimum MSE value is '+str(min_mse))

x = np.array(dataset[:,0])
y = np.array(dataset[:,1])

plt.plot(x,y,"o")
for n in range(0,len(split_cb)):
    if n == 0:
        low = 0 
        high = split_cb[n]
    else:
        low = 0
        for i in range(0,n):
            low += split_cb[i]
        high = low + split_cb[n]
    x_tmp = pd.DataFrame(dataset[low:high,0])
    y_tmp = pd.DataFrame(dataset[low:high,1])
    lr_model = linear_model.LinearRegression()
    lr_model.fit(x_tmp,y_tmp)
    y_predict = lr_model.predict(x_tmp)
    plt.plot(x_tmp, y_predict, 'g-')

plt.show()

Please let me know if I didn't make any part of this clear.

Answer

It took me some time to realize that the problem you're describing is exactly what a decision tree regressor tries to solve.

Unfortunately, constructing an optimal decision tree is NP-hard [1], meaning that even with dynamic programming you can't bring the runtime down to anything like O(N log N).

The good news is that you can directly use any well-maintained decision tree implementation, for example DecisionTreeRegressor from the sklearn.tree module, and be certain of obtaining the best possible performance in O(N log N) time. To enforce a minimum number of points per group, use the min_samples_leaf parameter. You can also control several other properties, such as the maximum number of groups with max_leaf_nodes, or optimization with respect to different loss functions via criterion.
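As a minimal sketch of those knobs (the parameter values below are illustrative, chosen to mirror the question's minimum group size of 2 and its five-group ceiling; note that older scikit-learn versions spell the criterion "mse" instead of "squared_error"):

from sklearn.tree import DecisionTreeRegressor
import numpy as np

X = np.array(x).reshape(-1, 1)      # x, y from the question's code
dt = DecisionTreeRegressor(
    min_samples_leaf=2,             # at least 2 points per group
    max_leaf_nodes=5,               # at most 5 groups
    criterion="squared_error",      # split so as to minimize MSE
).fit(X, y)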

If you're curious how scikit-learn's decision tree compares with the one learnt by your algorithm (i.e. split_cb in your code):

from sklearn.tree import DecisionTreeRegressor

MIN_SIZE = 2  # the same minimum group size passed to best_piecewise above
X = np.array(x).reshape(-1, 1)
dt = DecisionTreeRegressor(min_samples_leaf=MIN_SIZE).fit(X, y)
# Leaf ids increase from left to right, so per-leaf counts give the group sizes.
split_cb = np.unique(dt.apply(X), return_counts=True)[1]

Then use the same plotting code you already have. Do note that since your time complexity is considerably higher than O(N log N)*, your implementation will often find better splits than scikit-learn's greedy algorithm does.
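For completeness, a sketch of that reuse, assuming the sorted dataset and the tree-derived split_cb from the snippet above (it also relies on scikit-learn's leaf ids increasing from left to right, which its depth-first tree builder produces):

# Plot the tree's groups with the same per-segment linear fits as the question.
plt.plot(dataset[:, 0], dataset[:, 1], "o")
low = 0
for size in split_cb:
    high = low + size
    x_tmp = pd.DataFrame(dataset[low:high, 0])
    y_tmp = pd.DataFrame(dataset[low:high, 1])
    lr_model = linear_model.LinearRegression().fit(x_tmp, y_tmp)
    plt.plot(x_tmp, lr_model.predict(x_tmp), "r-")  # red: tree-derived segments
    low = high
plt.show()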

[1] Hyafil,L.&Rivest,R.L.(1976).构造最佳的二元决策树是np-complete.信息处理快报,第5(1)页,第15-17页

[1] Hyafil, L., & Rivest, R. L. (1976). Constructing optimal binary decision trees is np-complete. Information Processing Letters, 5(1), 15–17

*Although I'm not sure about the exact time complexity of your implementation, it is almost certainly worse than O(N²): all_combination(21, 4, 2) took more than 5 minutes.
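If you want to see that blow-up directly, one quick (hypothetical) check is to time all_combination as n grows; the intermediate lists inside its loop multiply by m at every step, so the runtime grows exponentially in n:

import time

# Uses all_combination from the question; expect each step to take ~4x longer.
for n in range(8, 16):
    start = time.time()
    all_combination(n, 4, 2)
    print(n, round(time.time() - start, 3), 'seconds')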
