Keras network can never classify the last class

Question

I have been working on my project, Deep Learning Language Detection, which is a network that recognises 16 programming languages.

This is the code that builds the network and its layers:

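# Imports (assuming standalone Keras; with TensorFlow, import from tensorflow.keras instead)
from keras.layers import (Activation, Concatenate, Conv1D, Dense, Dropout,
                          Flatten, Input, MaxPooling1D)
from keras.models import Model, Sequential

# Hyperparameters such as sequence_length and filter_sizes are defined elsewhere in the project.
# Setting up the model: one Conv1D -> MaxPooling1D -> Flatten branch per filter size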
graph_in = Input(shape=(sequence_length, number_of_quantised_characters))
convs = []
for i in range(0, len(filter_sizes)):
    conv = Conv1D(filters=num_filters,
                  kernel_size=filter_sizes[i],
                  padding='valid',
                  activation='relu',
                  strides=1)(graph_in)
    pool = MaxPooling1D(pool_size=pooling_sizes[i])(conv)
    flatten = Flatten()(pool)
    convs.append(flatten)

if len(filter_sizes)>1:
    out = Concatenate()(convs)
else:
    out = convs[0]

graph = Model(inputs=graph_in, outputs=out)

# main sequential model
model = Sequential()
model.add(Dropout(dropout_prob[0], input_shape=(sequence_length, number_of_quantised_characters)))
model.add(graph)
model.add(Dense(hidden_dims))
model.add(Dropout(dropout_prob[1]))
model.add(Dense(number_of_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])

My last language class is SQL, and in the test phase the network never predicts SQL correctly: it scores 0% on SQL files. I thought this was due to the poor quality of the SQL samples (and they were indeed poor), so I removed this class and retrained on 15 classes. To my surprise, F# files now had 0% detection, and F# was the last class after removing SQL (i.e. the class whose one-hot vector has 1 in the last position and 0 elsewhere). If the network trained on 16 classes is tested against only the other 15, it achieves a very high success rate of about 98.7%.

The code I am using is pretty simple and lives mainly in defs.py and data_helper.py.

Here is the result of the network trained with 16 classes and tested against 16 classes:

Final result: 14827/16016 (0.925761738262)
xml:        995/1001 (0.994005994006)
fsharp:     974/1001 (0.973026973027)
clojure:    993/1001 (0.992007992008)
java:       996/1001 (0.995004995005)
scala:      990/1001 (0.989010989011)
python:     983/1001 (0.982017982018)
sql:        0/1001 (0.0)
js:         991/1001 (0.99000999001)
cpp:        988/1001 (0.987012987013)
css:        987/1001 (0.986013986014)
csharp:     994/1001 (0.993006993007)
go:         989/1001 (0.988011988012)
php:        998/1001 (0.997002997003)
ruby:       995/1001 (0.994005994006)
powershell: 992/1001 (0.991008991009)
bash:       962/1001 (0.961038961039)

And this is the result of the same network (trained on 16 classes) run against 15 classes:

Final result: 14827/15015 (0.987479187479)
xml:        995/1001 (0.994005994006)
fsharp:     974/1001 (0.973026973027)
clojure:    993/1001 (0.992007992008)
java:       996/1001 (0.995004995005)
scala:      990/1001 (0.989010989011)
python:     983/1001 (0.982017982018)
js:         991/1001 (0.99000999001)
cpp:        988/1001 (0.987012987013)
css:        987/1001 (0.986013986014)
csharp:     994/1001 (0.993006993007)
go:         989/1001 (0.988011988012)
php:        998/1001 (0.997002997003)
ruby:       995/1001 (0.994005994006)
powershell: 992/1001 (0.991008991009)
bash:       962/1001 (0.961038961039)

Has anyone else seen this? How can I get around it?

Recommended answer

TL;DR: The problem is that your data are not shuffled before being split into training and validation sets. Therefore, during training, all samples belonging to class "sql" are in the validation set. Your model won't learn to predict the last class if it hasn't been given samples in that class.

In get_input_and_labels(), the files for class 0 are first loaded, and then class 1, and so on. Since you set n_max_files = 2000, it means that

  • The first 2000 (or so, depends on how many files you actually have) entries in Y will be of class 0 ("go")
  • The next 2000 entries will be of class 1 ("csharp")
  • ...
  • and finally the last 2000 entries will be of the last class ("sql").

Unfortunately, Keras does not shuffle the data before splitting them into training and validation sets: validation_split takes the last fraction of the samples exactly as they were passed to fit(). Because validation_split is set to 0.1 in your code, about the last 3000 samples (which include all the "sql" samples) end up in the validation set.
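
The effect can be reproduced without training anything. The sketch below is a plain-Python emulation of an unshuffled tail split, assuming 16 classes with exactly 2000 samples each, loaded class by class. classes_missing_from_training is a hypothetical helper written for this illustration and is not part of Keras.

```python
def classes_missing_from_training(labels, validation_split):
    """Return the classes that appear only in the validation tail.

    Emulates a split that keeps the first (1 - validation_split) fraction
    of the samples for training and holds out the rest, with no shuffling.
    """
    split_at = int(len(labels) * (1.0 - validation_split))
    seen_in_training = set(labels[:split_at])
    return sorted(set(labels) - seen_in_training)


# Assumed layout: class 0 first, then class 1, ..., 2000 samples per class.
n_classes, n_per_class = 16, 2000
labels = [c for c in range(n_classes) for _ in range(n_per_class)]

print(classes_missing_from_training(labels, 0.1))  # [15]
print(classes_missing_from_training(labels, 0.2))  # [13, 14, 15]
```

Under these assumptions, a 0.1 split hides only the last class from training, while a 0.2 split hides the last three.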

If you set validation_split to a higher value (e.g., 0.2), you'll see more classes scoring 0%:

Final result: 12426/16016 (0.7758491508491508)
go:             926/1001 (0.9250749250749251)
csharp:         966/1001 (0.965034965034965)
java:           973/1001 (0.972027972027972)
js:             929/1001 (0.9280719280719281)
cpp:            986/1001 (0.985014985014985)
ruby:           942/1001 (0.9410589410589411)
powershell:     981/1001 (0.98001998001998)
bash:           882/1001 (0.8811188811188811)
php:            977/1001 (0.9760239760239761)
css:            988/1001 (0.987012987012987)
xml:            994/1001 (0.993006993006993)
python:         986/1001 (0.985014985014985)
scala:          896/1001 (0.8951048951048951)
clojure:        0/1001 (0.0)
fsharp:         0/1001 (0.0)
sql:            0/1001 (0.0)


The problem can be solved if you shuffle the data after loading. It seems that you already have lines shuffling the data:

# Shuffle data
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices].argmax(axis=1)

However, when you fit the model, you passed the original x and y to fit() instead of x_shuffled and y_shuffled. If you change the line into:

model.fit(x_shuffled, y_shuffled, batch_size=batch_size,
          epochs=num_epochs, validation_split=val_split, verbose=1)

The testing output would become more reasonable:

Final result: 15248/16016 (0.952047952047952)
go:             865/1001 (0.8641358641358642)
csharp:         986/1001 (0.985014985014985)
java:           977/1001 (0.9760239760239761)
js:             953/1001 (0.952047952047952)
cpp:            974/1001 (0.973026973026973)
ruby:           985/1001 (0.984015984015984)
powershell:     974/1001 (0.973026973026973)
bash:           942/1001 (0.9410589410589411)
php:            979/1001 (0.978021978021978)
css:            965/1001 (0.964035964035964)
xml:            988/1001 (0.987012987012987)
python:         857/1001 (0.8561438561438561)
scala:          955/1001 (0.954045954045954)
clojure:        985/1001 (0.984015984015984)
fsharp:         950/1001 (0.949050949050949)
sql:            913/1001 (0.9120879120879121)
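
A cheap guard against this kind of mistake is to check, before calling fit(), that every class is represented in the portion of the data that will be used for training. The following is a minimal NumPy sketch, assuming that y is one-hot encoded and that the split holds out the last validation_split fraction of the samples. shuffle_together and training_class_counts are hypothetical helper names, and the toy arrays stand in for the real data.

```python
import numpy as np


def shuffle_together(x, y, seed=None):
    """Shuffle x and y with one shared permutation so pairs stay aligned."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(y))
    return x[perm], y[perm]


def training_class_counts(y_one_hot, validation_split):
    """Count samples per class in the part that a tail split keeps for training."""
    split_at = int(len(y_one_hot) * (1.0 - validation_split))
    return y_one_hot[:split_at].sum(axis=0).astype(int)


# Toy data: 16 classes, 200 samples each, stored class by class.
n_classes, n_per_class = 16, 200
class_ids = np.repeat(np.arange(n_classes), n_per_class)
y = np.eye(n_classes)[class_ids]        # one-hot labels
x = class_ids.reshape(-1, 1)            # dummy features that encode the label

counts_before = training_class_counts(y, 0.1)
x_shuffled, y_shuffled = shuffle_together(x, y, seed=0)
counts_after = training_class_counts(y_shuffled, 0.1)

print("classes with no training samples before shuffling:",
      np.flatnonzero(counts_before == 0))
print("classes with no training samples after shuffling:",
      np.flatnonzero(counts_after == 0))
```

The labels are kept one-hot in this sketch. Whether fit() expects one-hot vectors or integer class indices depends on the loss the model was compiled with.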
