How to use k-fold cross validation in a neural network

Problem Description

We are writing a small ANN which is supposed to categorize 7000 products into 7 classes based on 10 input variables.

In order to do this we have to use k-fold cross validation but we are kind of confused.

We have this excerpt from the presentation slide:

What exactly are the validation and test sets?

From what we understand, we run through the 3 training sets and adjust the weights (a single epoch). But then what do we do with the validation set? Because from what I understand, the test set is used to obtain the error of the network.

What happens next is also confusing to me. When does the crossover take place?

If it's not too much to ask, a bullet list of the steps would be appreciated.

Solution

You seem to be a bit confused (I remember I was too) so I am going to simplify things for you. ;)

Sample Neural Network Scenario

Whenever you are given a task such as devising a neural network you are often also given a sample dataset to use for training purposes. Let us assume you are training a simple neural network system Y = W · X, where Y is the output computed by taking the scalar product (·) of the weight vector W with a given sample vector X. Now, the naive way to go about this would be to use the entire dataset of, say, 1000 samples to train the network. Assuming that the training converges and your weights stabilise, you can then safely say that your network will correctly classify the training data. But what happens to the network if it is presented with previously unseen data? Clearly the purpose of such systems is to be able to generalise and correctly classify data other than the data used for training.
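
As a toy illustration of that model, here is a minimal sketch in Python; the weights and inputs are made-up values, not anything from the question:

import numpy as np

# Hypothetical weight vector and input sample (10 input variables,
# matching the question's setup; the values themselves are arbitrary)
W = np.random.rand(10)
X = np.random.rand(10)

# The output is the scalar product of the weights with the sample
Y = np.dot(W, X)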

Overfitting Explained

In any real-world situation, however, previously-unseen/new data is only available once your neural network is deployed in a, let's call it, production environment. But since you have not tested it adequately you are probably going to have a bad time. :) The phenomenon by which any learning system matches its training set almost perfectly but constantly fails with unseen data is called overfitting.

The Three Sets

This is where the validation and testing parts of the algorithm come in. Let's go back to the original dataset of 1000 samples. What you do is split it into three sets -- training, validation and testing (Tr, Va and Te) -- using carefully selected proportions. An (80-10-10)% split is usually a good choice (a minimal split is sketched right after this list), where:

  • Tr = 80%
  • Va = 10%
  • Te = 10%
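
A minimal sketch of such a split in Python, under the assumption that the 1000 samples sit in a plain list and that shuffling before splitting is acceptable (the variable names are illustrative):

import random

data = list(range(1000))   # stand-in for the 1000 samples

random.shuffle(data)       # randomise the order before splitting
n = len(data)

Tr = data[: int(0.8 * n)]                # 80% -> training
Va = data[int(0.8 * n) : int(0.9 * n)]   # 10% -> validation
Te = data[int(0.9 * n) :]                # 10% -> testing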

Training and Validation

Now what happens is that the neural network is trained on the Tr set and its weights are updated accordingly. The validation set Va is then used to compute the classification error E = M - Y using the weights resulting from the training, where M is the expected output vector taken from the validation set and Y is the computed output resulting from the classification (Y = W · X). If the error is higher than a user-defined threshold then the whole training-validation epoch is repeated. This training phase ends when the error computed using the validation set is deemed low enough.
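
The loop described above might look as follows; this is a sketch only, using a delta-rule weight update and a mean squared validation error as stand-ins for whatever update rule and error measure your network actually uses:

import numpy as np

# Toy setup (assumed, not from the question): 10 inputs, targets that
# are genuinely linear in X so the model is able to learn them.
rng = np.random.default_rng(0)
W_true = rng.random(10)
X_tr = rng.random((800, 10))
M_tr = X_tr @ W_true
X_va = rng.random((100, 10))
M_va = X_va @ W_true

W = np.zeros(10)               # weights to be learned
lr, threshold = 0.01, 1e-4

for epoch in range(1000):      # safety cap on the number of epochs
    # One training epoch: adjust W on every training sample
    for x, m in zip(X_tr, M_tr):
        y = W @ x              # Y = W · X
        W += lr * (m - y) * x  # nudge W to shrink the error E = M - Y

    # Validation: stop once the mean squared error on Va is low enough
    E = M_va - X_va @ W
    if np.mean(E ** 2) < threshold:
        break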

Smart Training

Now, a smart ruse here is to randomly select which samples to use for training and validation from the total set Tr + Va at each epoch iteration. This helps ensure that the network does not overfit the training set.
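
Continuing from the split sketched earlier, that reshuffling step could look like this (Tr and Va are the lists from the splitting sketch above):

import random

pool = Tr + Va            # the combined 90% training + validation pool
random.shuffle(pool)      # re-draw the split at every epoch
n_tr = len(Tr)            # keep the same 80/10 proportions
Tr_epoch = pool[:n_tr]
Va_epoch = pool[n_tr:]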

Testing

The testing set Te is then used to measure the performance of the network. This data is perfect for this purpose as it was never used throughout the training and validation phase. It is effectively a small set of previously unseen data, which is supposed to mimic what would happen once the network is deployed in the production environment.

The performance is again measured in terms of classification error, as explained above. The performance can also (or maybe even should) be measured in terms of precision and recall so as to know where and how the error occurs, but that's a topic for another Q&A.
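
For the 7-class problem in the question, the classification error and per-class precision and recall could be computed along these lines; the label vectors here are invented for illustration:

# Invented true and predicted labels for a 7-class problem (classes 0-6)
y_true = [0, 1, 2, 2, 3, 4, 5, 6, 6, 1]
y_pred = [0, 1, 2, 3, 3, 4, 5, 6, 1, 1]

# Classification error: fraction of misclassified samples
error = sum(t != p for t, p in zip(y_true, y_pred)) / len(y_true)

# Precision and recall for one class, e.g. class 1
tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
fp = sum(t != 1 and p == 1 for t, p in zip(y_true, y_pred))
fn = sum(t == 1 and p != 1 for t, p in zip(y_true, y_pred))
precision = tp / (tp + fp)  # of all predicted as class 1, how many were right
recall = tp / (tp + fn)     # of all truly class 1, how many were found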

Cross-Validation

Having understood this training-validation-testing mechanism, one can further strengthen the network against overfitting by performing k-fold cross-validation. This is somewhat of an evolution of the smart ruse I explained above. This technique involves performing k rounds of training-validation-testing on different, non-overlapping, equally proportioned Tr, Va and Te sets.

Given k = 10, for each of the k rounds you split your dataset into Tr + Va = 90% and Te = 10%, run the algorithm, and record the testing performance.

k = 10
fold_size = len(data) // k     # data: the full sample set from above

# smart_train is a placeholder for the training-validation procedure
# described in the Smart Training section; it returns a performance score.
k_fold_performance = []

for i in range(k):
    # Select unique, non-overlapping training and testing subsets
    k_fold_testing = data[i * fold_size : (i + 1) * fold_size]
    k_fold_training = data[: i * fold_size] + data[(i + 1) * fold_size :]

    # Train and record this fold's testing performance
    k_fold_performance.append(smart_train(k_fold_training, k_fold_testing))

# Compute overall performance, e.g. as the mean across the k folds
total_performance = sum(k_fold_performance) / k

Overfitting Shown

I am taking the world-famous plot below from Wikipedia to show how the validation set helps prevent overfitting. The training error, in blue, tends to decrease as the number of epochs increases: the network is attempting to match the training set exactly. The validation error, in red, on the other hand follows a U-shaped profile. Ideally, training should be stopped at the minimum of that curve, since this is the point at which the validation error, and hence the generalisation error, is lowest.

References

For more references, this excellent book will give you both a sound knowledge of machine learning and several migraines. Up to you to decide if it's worth it. :)
