Cross Validation in Keras


Problem Description

I'm implementing a Multilayer Perceptron in Keras and using scikit-learn to perform cross-validation. For this, I was inspired by the code found in the issue Cross Validation in Keras:

from sklearn.model_selection import StratifiedKFold

def load_data():
    # load your data using this function
    ...

def create_model():
    # create your model using this function
    ...

def train_evaluate(model, data_train, labels_train, data_test, labels_test):
    # fit and evaluate here.
    ...

if __name__ == "__main__":
    X, Y = load_data()
    kFold = StratifiedKFold(n_splits=10)
    for train, test in kFold.split(X, Y):
        model = None
        model = create_model()
        train_evaluate(model, X[train], Y[train], X[test], Y[test])
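For reference, the skeleton above can be fleshed out into something runnable. The sketch below is numpy-only (no Keras): the hand-rolled stratified split and the nearest-centroid "model" are hypothetical stand-ins chosen so the structure of the loop stays visible, not the question's actual network:

```python
import numpy as np

def stratified_kfold_indices(y, n_splits, seed=0):
    """Yield (train_idx, test_idx) pairs, keeping class proportions per fold."""
    rng = np.random.default_rng(seed)
    folds = [[] for _ in range(n_splits)]
    for cls in np.unique(y):
        idx = np.where(y == cls)[0]
        rng.shuffle(idx)
        for i, j in enumerate(idx):
            folds[i % n_splits].append(j)
    for k in range(n_splits):
        test = np.array(folds[k])
        train = np.array([j for f in folds if f is not folds[k] for j in f])
        yield train, test

def create_model():
    # stand-in "model": a dict mapping class label -> class centroid
    return {}

def train_evaluate(model, X_tr, y_tr, X_te, y_te):
    # "fit": store per-class centroids; "evaluate": nearest-centroid accuracy
    for cls in np.unique(y_tr):
        model[cls] = X_tr[y_tr == cls].mean(axis=0)
    classes = np.array(sorted(model))
    centroids = np.stack([model[c] for c in classes])
    pred = classes[((X_te[:, None, :] - centroids) ** 2).sum(axis=-1).argmin(axis=1)]
    return float((pred == y_te).mean())

rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(3, 1, (50, 2))])
Y = np.repeat([0, 1], 50)

scores = []
for train, test in stratified_kfold_indices(Y, n_splits=5):
    model = create_model()  # a fresh model for every fold
    scores.append(train_evaluate(model, X[train], Y[train], X[test], Y[test]))
print(round(float(np.mean(scores)), 2))
```

In real use you would replace `create_model`/`train_evaluate` with a Keras model and its `fit`/`evaluate` calls, and `stratified_kfold_indices` with scikit-learn's `StratifiedKFold`.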

In my studies on neural networks, I learned that the knowledge representation of a neural network lies in its synaptic weights, and that during the training process the weights are updated to reduce the network's error rate and improve its performance. (In my case, I'm using supervised learning.)

For better training and assessment of neural network performance, a commonly used method is cross-validation, which returns partitions of the data set for training and evaluating the model.

My question is...

In this code snippet:

for train, test in kFold.split(X, Y):
    model = None
    model = create_model()
    train_evaluate(model, X[train], Y[train], X[test], Y[test])

Do we define, train and evaluate a new neural network for each of the generated partitions?

If my goal is to fine-tune the network for the entire dataset, why is it not correct to define a single neural network and train it with the generated partitions?

That is, why is this piece of code like this?

for train, test in kFold.split(X, Y):
    model = None
    model = create_model()
    train_evaluate(model, X[train], Y[train], X[test], Y[test])

And not like this?

model = None
model = create_model()
for train, test in kFold.split(X, Y):
    train_evaluate(model, X[train], Y[train], X[test], Y[test])

Is my understanding of how the code works wrong? Or is my theory wrong?

Thanks!

Answer

If my goal is to fine-tune the network for the entire dataset

It is not clear what you mean by "fine-tune", or even what exactly your purpose is in performing cross-validation (CV); in general, CV serves one of the following purposes:

  • Model selection (choosing the values of hyperparameters)
  • Model assessment

Since you don't define any search grid for hyperparameter selection in your code, it would seem that you are using CV to get the expected performance of your model (error, accuracy, etc.).
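When CV is used this way, the usual summary statistic is the mean (and spread) of the per-fold scores; a minimal sketch with hypothetical fold accuracies:

```python
import numpy as np

fold_scores = [0.81, 0.79, 0.84, 0.80, 0.82]  # hypothetical per-fold accuracies
print(f"CV accuracy: {np.mean(fold_scores):.3f} +/- {np.std(fold_scores):.3f}")
# → CV accuracy: 0.812 +/- 0.017
```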

In any case, for whatever reason you are using CV, the first snippet is the correct one; your second snippet

model = None
model = create_model()
for train, test in kFold.split(X, Y):
    train_evaluate(model, X[train], Y[train], X[test], Y[test])

will train your model sequentially over the different partitions (i.e. train on partition #1, then continue training on partition #2, etc.), which is essentially just training on your whole dataset, and it is certainly not cross-validation...
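To see why this is not CV, consider a toy tracker in place of the network (`SeenTracker` is a hypothetical stand-in that only records which sample indices it has been fitted on):

```python
# Hypothetical stand-in for a Keras model: it only records which
# sample indices it has been fitted on.
class SeenTracker:
    def __init__(self):
        self.seen = set()

    def fit(self, idx):
        self.seen.update(idx)

n, k = 100, 5
folds = [list(range(i, n, k)) for i in range(k)]

# The second snippet's pattern: ONE model reused across all folds
model = SeenTracker()
leaks = []
for f in range(k):
    test = folds[f]
    train = [i for g in range(k) if g != f for i in folds[g]]
    # count test samples the model has ALREADY trained on before this fold
    leaks.append(sum(i in model.seen for i in test))
    model.fit(train)

print(leaks)  # [0, 20, 20, 20, 20]
```

From the second fold onward, every "test" sample has already been used for training, so the fold scores no longer measure generalization; creating a fresh model per fold, as in the first snippet, avoids this leakage.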

That said, a final step after the CV, which is often only implied (and frequently missed by beginners), is that once you are satisfied with your chosen hyperparameters and/or model performance as given by your CV procedure, you go back and train your model again, this time with the entire available data.
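The overall workflow then looks like the sketch below, where `create_model`/`fit` are hypothetical stubs and `best_units` stands in for whatever value the CV procedure selected:

```python
import numpy as np

# Hypothetical stubs standing in for the question's helpers
def create_model(hidden_units):
    return {"hidden_units": hidden_units, "trained_on": 0}

def fit(model, X, y):
    model["trained_on"] = len(X)  # a real model would run model.fit(X, y) here
    return model

rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 4)), rng.integers(0, 2, 100)

best_units = 32  # hypothetical value selected by the CV procedure

# Final (often implicit) step: retrain ONE fresh model on ALL available data
final_model = fit(create_model(best_units), X, y)
print(final_model["trained_on"])  # 100: every sample is used
```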
