Change training dataset every N epochs in Keras

Question

I would like to pass a new training dataset (X_train, y_train) every N epochs in Keras, where (X_train, y_train) are obtained through Monte Carlo simulations.

In pseudo-code, it would be done by:

for i in range(nb_total_epochs):
    if i%N == 0:
        X_train, y_train = generate_new_dataset(simulation_parameters)
    train_model(X_train, y_train)

Is there any existing trick to achieve this with the fit() function?
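
For reference, the pseudo-code above maps fairly directly onto repeated fit() calls, since model weights carry over from one call to the next. The following is only a minimal sketch: fit_with_refresh is a hypothetical helper name, generate_new_dataset is assumed here to be a zero-argument callable wrapping the simulation, and model is assumed to be an already compiled Keras model. The initial_epoch and epochs arguments keep the epoch numbering continuous across calls.

```python
def fit_with_refresh(model, generate_new_dataset, nb_total_epochs, n, **fit_kwargs):
    """Train in chunks of at most `n` epochs, drawing a new dataset before each chunk.

    `generate_new_dataset` is a hypothetical zero-argument callable that
    returns (X_train, y_train); wrap the simulation parameters in it.
    """
    histories = []
    for start in range(0, nb_total_epochs, n):
        x_train, y_train = generate_new_dataset()
        stop = min(start + n, nb_total_epochs)
        # When initial_epoch is given, `epochs` is the index of the final
        # epoch to reach, not the number of epochs to run.
        history = model.fit(
            x_train, y_train, initial_epoch=start, epochs=stop, **fit_kwargs
        )
        histories.append(history)
    return histories
```

Each fit() call returns its own History object, so the helper collects them in a list rather than a single training history.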

Recommended answer

Use Sequence to create your dataset and pass it to fit_generator. Define the on_epoch_end method to modify the dataset on certain epochs.

Every Sequence must implement the __getitem__ and __len__ methods. If you want to modify your dataset between epochs, you may implement on_epoch_end. The __getitem__ method should return a complete batch.

Also, you can safely use Sequence with multiprocessing data processing:

The use of keras.utils.Sequence guarantees the ordering and guarantees the single use of every input per epoch when using use_multiprocessing=True.

Example

Slightly modified from the Sequence documentation to include on_epoch_end.

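import numpy as np
from skimage.io import imread
from skimage.transform import resize
from keras.utils import Sequence

N = 10  # placeholder value: number of epochs between dataset changes


class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        # x_set holds image file names, y_set the matching labels
        self.x, self.y = x_set, y_set
        self.epoch = 0
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Return one complete batch: load and resize the images for batch idx
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)

    def on_epoch_end(self):
        # Called after each epoch; count the finished epoch before the check
        # so the data changes after N, 2N, ... epochs rather than after the first
        self.epoch += 1
        if self.epoch % N == 0:
            pass
            # modify data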
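
Applied to the Monte Carlo case in the question, the "modify data" step could look roughly like the sketch below. Everything specific here is an assumption made for illustration: generate_new_dataset is a hypothetical stand-in that just draws random numbers, MonteCarloSequence and refresh_every are made-up names, and the new dataset is assumed to have the same size as the old one so that __len__ stays constant.

```python
import numpy as np

try:
    from keras.utils import Sequence
except ImportError:
    # Fallback so the sketch can be exercised without Keras installed.
    Sequence = object


def generate_new_dataset(n_samples, rng):
    """Hypothetical stand-in for the Monte Carlo simulation."""
    x = rng.normal(size=(n_samples, 4)).astype("float32")
    y = x.sum(axis=1, keepdims=True)
    return x, y


class MonteCarloSequence(Sequence):
    """Serves batches and draws a fresh dataset every `refresh_every` epochs."""

    def __init__(self, n_samples, batch_size, refresh_every, seed=0):
        super().__init__()
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.refresh_every = refresh_every
        self.rng = np.random.default_rng(seed)
        self.epoch = 0
        self.x, self.y = generate_new_dataset(n_samples, self.rng)

    def __len__(self):
        # Number of batches per epoch (the last one may be smaller)
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, idx):
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[rows], self.y[rows]

    def on_epoch_end(self):
        # Count the epoch that just finished, then regenerate if it is due
        self.epoch += 1
        if self.epoch % self.refresh_every == 0:
            self.x, self.y = generate_new_dataset(self.n_samples, self.rng)
```

With a compiled model, an instance would typically be passed as model.fit(seq, epochs=nb_total_epochs). In recent TensorFlow/Keras releases fit() accepts a Sequence directly and fit_generator is deprecated, which also covers the original question about fit().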
