Model with BatchNormalization: stagnant test loss


Problem description

I wrote a neural network using Keras. It contains BatchNormalization layers.

When I trained it with model.fit, everything was fine. When training it with tensorflow as explained here, the training goes fine, but the validation step always gives very poor performance and quickly saturates (the accuracy goes 5%, 10%, 40%, 40%, 40%...; the loss stagnates too).
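For context, here is a minimal sketch of this kind of hybrid setup, where the model is built from Keras layers but trained by hand in a TensorFlow session. All names here (labels, train_op, accuracy, the layer sizes) are illustrative assumptions, not taken from the question:

import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Dense, BatchNormalization

sess = tf.Session()
K.set_session(sess)  # let the Keras layers build into this TF session

# Hypothetical model: shapes and layer sizes are placeholders
x = Input(shape=(64,))
h = Dense(32, activation='relu')(x)
h = BatchNormalization()(h)
logits = Dense(10)(h)

labels = tf.placeholder(tf.float32, [None, 10])
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
accuracy = tf.reduce_mean(tf.cast(
    tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)), tf.float32))

sess.run(tf.global_variables_initializer())
# One training step (learning_phase 1 = training mode):
# sess.run(train_op, feed_dict={x: X_batch, labels: y_batch,
#                               K.learning_phase(): 1})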

I need to use tensorflow because it allows more flexibility in monitoring the training.

I strongly suspect it has something to do with the BN layers and/or the way I compute the test performance (see below):

feed_dict = {x: X_valid,
             batch_size_placeholder: X_valid.shape[0],
             K.learning_phase(): 0,  # 0 = inference mode for Keras layers
             beta: self.warm_up_schedule(global_step)
             }
if self.weights is not None:
    feed_dict[weights] = self.weights
acc = accuracy.eval(feed_dict=feed_dict)

Is there anything special to do when computing the validation accuracy of a model containing Keras BatchNormalization layers?

Thanks in advance!

Answer

Actually, I found out about the training argument of the __call__ method of the BatchNormalization layer.

So what you can do when instantiating the layer is just:

from keras import backend as K
from keras.layers import Input, Dense, BatchNormalization

x = Input((dim1, dim2))
h = Dense(dim3)(x)
h = BatchNormalization()(h, training=K.learning_phase())  # follow the Keras learning phase

And when evaluating the performance on the validation set, feed K.learning_phase(): 0 so that BatchNormalization normalizes with its moving mean and variance rather than the current batch's statistics:

feed_dict = {x: X_valid,
             batch_size_placeholder: X_valid.shape[0],
             K.learning_phase(): 0,  # inference mode: BN uses moving statistics
             beta: self.warm_up_schedule(global_step)
             }
acc = accuracy.eval(feed_dict=feed_dict)
summary_ = merged.eval(feed_dict=feed_dict)
test_writer.add_summary(summary_, global_step)
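For the training step, the counterpart is simply feeding K.learning_phase(): 1. One extra caveat, which is an assumption on my part rather than part of the answer above: when you run the train op yourself instead of calling model.fit, the moving-average update ops that BatchNormalization registers are not executed automatically, so you may need to run them alongside the train op. A hedged sketch, assuming the layers are wrapped in a keras.models.Model called model (train_op, X_batch, y_batch are also assumed names):

# Collect the BN moving mean/variance updates and run them with the train op
update_ops = model.updates
sess.run([train_op] + update_ops,
         feed_dict={x: X_batch, labels: y_batch, K.learning_phase(): 1})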
