在Keras中每N个时期更改训练数据集 [英] Change training dataset every N epochs in Keras
问题描述
我想在Keras中的每个N
时期传递另一个训练数据集(X_train, y_train)
,其中(X_train, y_train)
是通过蒙特卡洛模拟获得的.
I would like to pass another training dataset (X_train, y_train)
every N
epochs in Keras, where (X_train, y_train)
are obtained through Monte Carlo simulations.
在伪代码中,它将通过以下方式完成:
In pseudo-code, it would be done by:
for i in range(nb_total_epochs):
if i%N == 0:
X_train, y_train = generate_new_dataset(simulation_parameters)
train_model(X_train, y_train)
是否存在使用fit()
函数实现此目的的技巧?
Is there any existing trick to achieve this with the fit()
function?
推荐答案
使用 Sequence
进行创建数据集并将其传递给 fit_generator
.定义on_epoch_end
方法以在某些时期修改数据集.
Use Sequence
to create your dataset and pass it to fit_generator
. Define the on_epoch_end
method to modify the dataset on certain epochs.
每个
Sequence
必须实现__getitem__
和__len__
方法. 如果您想在各个时期之间修改数据集,则可以实现on_epoch_end
.方法__getitem__
应该返回完整的批处理.
Every
Sequence
must implements the__getitem__
and the__len__
methods. If you want to modify your dataset between epochs you may implementon_epoch_end
. The method__getitem__
should return a complete batch.
此外,您可以安全地将Sequence
用于多处理数据处理:
Also, you can safely use Sequence
with multiprocessing data processing:
使用
keras.utils.Sequence
可以保证顺序,并保证在使用use_multiprocessing=True
时每个历元的每个输入都可以使用.
The use of
keras.utils.Sequence
guarantees the ordering and guarantees the single use of every input per epoch when usinguse_multiprocessing=True
.
示例
从Sequence
文档中进行了略微修改,以包含on_epoch_end
.
Example
Slightly modified from the Sequence
documentation to include on_epoch_end
.
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.epoch = 0
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
def on_epoch_end(self):
if self.epoch % N == 0:
pass
# modify data
self.epoch += 1
这篇关于在Keras中每N个时期更改训练数据集的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!