Cross Validation in Keras


Question

I'm implementing a Multilayer Perceptron in Keras and using scikit-learn to perform cross-validation. For this, I was inspired by the code found in the issue Cross Validation in Keras

from sklearn.model_selection import StratifiedKFold  # sklearn.cross_validation was removed; use model_selection

def load_data():
    # load your data using this function
    ...

def create_model():
    # create your model using this function
    ...

def train_and_evaluate_model(model, x_train, y_train, x_test, y_test):
    # fit and evaluate here.
    ...

if __name__ == "__main__":
    X, Y = load_data()
    kfold = StratifiedKFold(n_splits=10)
    for train, test in kfold.split(X, Y):
        model = create_model()
        train_and_evaluate_model(model, X[train], Y[train], X[test], Y[test])

In my studies of neural networks, I learned that a network's knowledge is represented in its synaptic weights, and that during training the weights are updated so as to reduce the network's error rate and improve its performance. (In my case, I'm using supervised learning.)

For better training and assessment of neural network performance, a commonly used method is cross-validation, which returns partitions of the data set for training and evaluating the model.

My question is this...

In this code snippet:

for train, test in kfold.split(X, Y):
    model = create_model()
    train_and_evaluate_model(model, X[train], Y[train], X[test], Y[test])

Do we define, train, and evaluate a new neural network for each of the generated partitions?

If my goal is to fine-tune the network for the entire dataset, why is it not correct to define a single neural network and train it with the generated partitions?

That is, why is this piece of code like this?

for train, test in kfold.split(X, Y):
    model = create_model()
    train_and_evaluate_model(model, X[train], Y[train], X[test], Y[test])

And not like this?

model = create_model()
for train, test in kfold.split(X, Y):
    train_and_evaluate_model(model, X[train], Y[train], X[test], Y[test])

Is my understanding of how the code works wrong? Or is my theory wrong?

Answer

If my goal is to fine-tune the network for the entire dataset

It is not clear what you mean by "fine-tune", or even what exactly your purpose is in performing cross-validation (CV); in general, CV serves one of the following purposes:

  • Model selection (choosing values for the hyperparameters)
  • Model assessment

Since you don't define any search grid for hyperparameter selection in your code, it would seem that you are using CV to get the expected performance of your model (error, accuracy, etc.).
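As a sketch of that use of CV, scikit-learn's cross_val_score computes such a per-fold performance estimate in one call. Here a plain scikit-learn classifier and synthetic data stand in for the Keras model (a Keras model can be adapted to this API through a scikit-learn wrapper such as scikeras):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic stand-in data; in the question these come from load_data().
X, Y = make_classification(n_samples=200, n_features=10, random_state=0)

# cross_val_score clones the estimator for every fold, so each fold
# starts from a fresh, untrained model.
scores = cross_val_score(LogisticRegression(max_iter=1000), X, Y,
                         cv=StratifiedKFold(n_splits=10))
print(scores.mean(), scores.std())
```

The key property is the cloning: the estimator you pass in is never fitted across folds, which is exactly the behavior the first snippet implements by hand.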

Anyway, for whatever reason you are using CV, the first snippet is the correct one; your second snippet

model = create_model()
for train, test in kfold.split(X, Y):
    train_and_evaluate_model(model, X[train], Y[train], X[test], Y[test])

will train your model sequentially over the different partitions (i.e. train on partition #1, then continue training on partition #2, etc.), which is essentially just training on your whole data set, and is certainly not cross-validation...
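The correct pattern can be sketched concretely. This toy version of the first snippet uses a scikit-learn classifier and synthetic data as stand-ins for the Keras model, but the structure is the same: a fresh model is created inside the loop, so no fold ever sees weights learned from another fold:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

# Synthetic stand-in data; in the question these come from load_data().
X, Y = make_classification(n_samples=200, n_features=10, random_state=0)

kfold = StratifiedKFold(n_splits=10)
scores = []
for train, test in kfold.split(X, Y):
    model = LogisticRegression(max_iter=1000)  # fresh model per fold
    model.fit(X[train], Y[train])
    scores.append(model.score(X[test], Y[test]))

print(f"CV accuracy: {np.mean(scores):.3f} +/- {np.std(scores):.3f}")
```

Moving the model creation above the loop, as in the second snippet, would let each fold continue from the previous fold's weights, so every "test" partition would eventually overlap data the model had already trained on.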

That said, a final step after CV, which is often only implied (and frequently missed by beginners), is that once you are satisfied with the hyperparameters and/or model performance given by your CV procedure, you go back and train your model again, this time on the entire available data.
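That final step can be sketched as follows (again with a scikit-learn classifier and synthetic data as hypothetical stand-ins): once CV has produced a satisfactory performance estimate, one last model is fit on all of the data, with no held-out partition, and that is the model you keep:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Stand-in for the full dataset that was partitioned during CV.
X, Y = make_classification(n_samples=200, n_features=10, random_state=0)

# The CV loop only *estimated* performance; the deployable model is
# trained once more, on the entire available data.
final_model = LogisticRegression(max_iter=1000)
final_model.fit(X, Y)
predictions = final_model.predict(X)
```

The per-fold models from the CV loop are discarded; their only purpose was to produce the performance estimate that justifies trusting this final model.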

