Keras model.fit()与tf.dataset API + validation_data [英] Keras model.fit() with tf.dataset API + validation_data
问题描述
所以我通过以下代码使我的keras模型可以与tf.Dataset一起使用:
So I have got my keras model to work with a tf.Dataset through the following code:
# Initialize batch generators(returns tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)
# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()
# Create Model
logits = .....(some layers)
keras.models.Model(inputs=dataset_inputs, outputs=logits)
# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_epoch=num_batches, callbacks=callbacks, verbose=1)
但是,当我尝试将validation_data
参数传递给模型时.适合它告诉我不能与发电机一起使用.有没有一种方法可以在使用tf.Dataset时使用验证
however when I try to pass validation_data
parameter to the model. fit it tells me that I cannot use it with the generator. Is there a way to use validation while using tf.Dataset
例如在tensorflow中,我可以执行以下操作:
# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)
# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
batch_train.output_shapes)
# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)
然后只需使用sess.run(init_op_train)
和sess.run(init_op_valid)
在数据集之间切换
then just use sess.run(init_op_train)
and sess.run(init_op_valid)
to switch between the datasets
我尝试实现一个仅执行此操作的回调(切换到验证集,进行预测并返回),但它告诉我不能在回调中使用model.predict
I tried implementing a callback that does just that (switch to validation set, predict and back) but it tells me I can't use model.predict in a callback
有人可以帮助我使用Keras + Tf.Dataset进行验证
can someone help me get validation working with Keras+Tf.Dataset
最后,对我来说最有用的是,归功于所选的答案:
# Initialize batch generators(returns tf.Dataset)
batch_train = # returns tf.Dataset
batch_valid = # returns tf.Dataset
# Create TensorFlow Iterator object and wrap it in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_train)
# Create Model
logits = # the keras model
keras.models.Model(inputs=dataset_inputs, outputs=logits)
# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit_generator(
generator=itr_train, validation_data=itr_valid, validation_steps=batch_size,
epochs=epochs, steps_per_epoch=num_batches, callbacks=cbs, verbose=1, workers=0)
def make_iterator(dataset):
iterator = dataset.make_one_shot_iterator()
next_val = iterator.get_next()
with K.get_session().as_default() as sess:
while True:
*inputs, labels = sess.run(next_val)
yield inputs, labels
这不会带来任何间接费用
推荐答案
我通过使用fit_genertor解决了这个问题.我在此处找到了解决方案.我应用了@ Dat-Nguyen的解决方案.
I solved the problem by using fit_genertor. I found the solution here. I applied @Dat-Nguyen's solution.
您只需创建两个迭代器,一个用于训练,一个用于验证,然后创建您自己的生成器,即可从数据集中提取批次,并以(batch_data,batch_labels)的形式提供数据.最后,在model.fit_generator中,您将传递train_generator和validation_generator.
You need simply to create two iterators, one for training and one for validation and then create your own generator where you will extract batches from the dataset and provide the data in form of (batch_data, batch_labels) . Finally in model.fit_generator you will pass the train_generator and validation_generator.
这篇关于Keras model.fit()与tf.dataset API + validation_data的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!