How to preserve metric values over training sessions in Keras?


Problem description


I have a fit() function that uses the ModelCheckpoint() callback to save the model if it is better than any previous model, with save_weights_only=False so that the entire model is saved. This should allow me to resume training at a later date by using load_model().
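A minimal sketch of that setup, assuming tf.keras; the toy model, data, and file name are illustrative placeholders, not from the actual project:

import numpy as np
from tensorflow import keras
from tensorflow.keras import callbacks, layers

# Toy model and data, purely for illustration
model = keras.Sequential([keras.Input(shape=(4,)),
                          layers.Dense(8, activation='relu'),
                          layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')
x = np.random.rand(64, 4)
y = np.random.rand(64, 1)

# Save the entire model (architecture, weights, optimizer state)
# whenever val_loss improves
checkpoint = callbacks.ModelCheckpoint('best_model.h5',
                                       monitor='val_loss',
                                       save_best_only=True,
                                       save_weights_only=False)
model.fit(x, y, validation_split=0.25, epochs=5, callbacks=[checkpoint])

# In a later session, restore the full model and resume training:
# model = keras.models.load_model('best_model.h5')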


Unfortunately, somewhere in the save()/load_model() roundtrip, the metric values are not preserved -- for example, val_loss is set to inf. This means that when training resumes, after the first epoch ModelCheckpoint() will always save the model, which will almost always be worse than the previous champion from the earlier session.


I have determined that I can set ModelCheckpoint()'s current best value before resuming training, as follows:

myCheckpoint = ModelCheckpoint(...)
myCheckpoint.best = bestValueSoFar
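
When resuming, this might look roughly like the following; the file name and best value here are placeholders:

from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model

model = load_model('best_model.h5')            # restore the previously saved model

myCheckpoint = ModelCheckpoint('best_model.h5',
                               monitor='val_loss',
                               mode='min',
                               save_best_only=True)
myCheckpoint.best = 0.1234                     # best val_loss from the earlier session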


Obviously, I could monitor the values I need, write them out to a file, and read them in again when I resume, but given that I am a Keras newbie, I am wondering if I have missed something obvious.

Answer


I ended up quickly writing my own callback that keeps track of the best training values so I can reload them. It looks like this:

import json
import os

from tensorflow.keras import callbacks  # with standalone Keras: from keras import callbacks


# State monitor callback. Tracks how well we are doing and writes
# some state to a json file. This lets us resume training seamlessly.
#
# ModelState.state is:
#
# { "epoch_count": nnnn,
#   "best_values": { dictionary with keys for each log value },
#   "best_epoch": { dictionary with keys for each log value }
# }

class ModelState(callbacks.Callback):

    def __init__(self, state_path):

        super().__init__()

        self.state_path = state_path

        if os.path.isfile(state_path):
            print('Loading existing .json state')
            with open(state_path, 'r') as f:
                self.state = json.load(f)
        else:
            self.state = { 'epoch_count': 0,
                           'best_values': {},
                           'best_epoch': {}
                         }

    def on_train_begin(self, logs=None):

        print('Training commences...')

    def on_epoch_end(self, epoch, logs=None):

        logs = logs or {}

        # Currently, for everything we track, lower is better

        for k in logs:
            if k not in self.state['best_values'] or logs[k] < self.state['best_values'][k]:
                self.state['best_values'][k] = float(logs[k])
                self.state['best_epoch'][k] = self.state['epoch_count']

        # Count the completed epoch before persisting, so a resumed run
        # starts at the next epoch instead of repeating this one.
        self.state['epoch_count'] += 1

        with open(self.state_path, 'w') as f:
            json.dump(self.state, f, indent=4)
        print('Completed epoch', self.state['epoch_count'])


Then, in the fit() function, something like this:

# Set up the model state, reading in prior results info if available

model_state = ModelState(path_to_state_file)

# Checkpoint the model if we get a better result

model_checkpoint = callbacks.ModelCheckpoint(path_to_model_file,
                                             monitor='val_loss',
                                             save_best_only=True,
                                             verbose=1,
                                             mode='min',
                                             save_weights_only=False)


# If we have trained previously, set up the model checkpoint so it won't save
# until it finds something better. Otherwise, it would always save the results
# of the first epoch.

if 'val_loss' in model_state.state['best_values']:
    model_checkpoint.best = model_state.state['best_values']['val_loss']

callback_list = [model_checkpoint,
                 model_state]

# Offset the epoch counts if we are resuming training. If you don't do
# this, only (epochs - initial_epoch) epochs will be run.

initial_epoch = model_state.state['epoch_count']
epochs += initial_epoch

# .fit() or .fit_generator, etc. goes here.
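
For completeness, a sketch of how the call itself might look; the model and data here are assumed to come from your own pipeline:

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          initial_epoch=initial_epoch,
          epochs=epochs,
          callbacks=callback_list)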

