如何使用GridSearchCV获得每组参数的预测? [英] How to get predictions for each set of parameters using GridSearchCV?

查看:1028
本文介绍了如何使用GridSearchCV获得每组参数的预测?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试使用带有以下代码的GridSearchCV为NN回归模型找到最佳参数:

I'm trying to find the best parameters for NN regression model using GridSearchCV with following code:

param_grid = dict(optimizer=optimizer, epochs=epochs, batch_size=batches, init=init
grid = GridSearchCV(estimator=model, param_grid=param_grid, scoring='neg_mean_squared_error')
grid_result = grid.fit(input_train, target_train)

pred = grid.predict(input_test)

据我了解,grid.predict(input_test)使用最佳参数来预测给定的输入集.有什么方法可以使用测试集评估每组参数的GridSearchCV?

As I understand, grid.predict(input_test) uses best parameters to predict the given input set. Is there any way to evaluate GridSearchCV for each set of parameters using test set?

实际上,我的测试集包含一些特殊记录,我想测试模型的一般性以及准确性.谢谢你.

Actually, my test set includes some special records and I want to test the generality of the model along with the accuracy. Thank you.

推荐答案

您可以使用自定义迭代器替换GridSearchCV的标准3折cv参数,以生成级联训练和测试数据帧的训练和测试索引.结果,在进行1次交叉验证时,您将在input_train对象上训练模型,并在input_test对象上测试拟合的模型:

You can replace standard 3-folds cv parameter of GridSearchCV with custom iterator, which yields train and test indices of concatenated train and test dataframes. In result, while 1-fold cross validation you'l train your model on input_train objects and test your fitted model on input_test objects:

def modified_cv(input_train_len, input_test_len):
    yield (np.array(range(input_train_len)), 
           np.array(range(input_train_len, input_train_len + input_test_len)))

input_train_len = len(input_train)
input_test_len = len(input_test)
data = np.concatenate((input_train, input_test), axis=0)
target = np.concatenate((target_train, target_test), axis=0)
grid = GridSearchCV(estimator=model, 
                    param_grid=param_grid,
                    cv=modified_cv(input_train_len, input_test_len), 
                    scoring='neg_mean_squared_error')
grid_result = grid.fit(data, target)

通过访问grid_result.cv_results_词典,您会在测试集中看到指定模型参数的所有网格的指标值.

By accessing grid_result.cv_results_ dictionary, you'l see your metrics value on test set for all grid of specified model parameters.

这篇关于如何使用GridSearchCV获得每组参数的预测?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆