Keras model.fit() 与 tf.dataset API + validation_data [英] Keras model.fit() with tf.dataset API + validation_data

查看:29
本文介绍了Keras model.fit() 与 tf.dataset API + validation_data的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

所以我通过以下代码让我的 keras 模型与 tf.Dataset 一起工作:

So I have got my keras model to work with a tf.Dataset through the following code:

# Initialize batch generators(returns tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)

# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()

# Create Model
logits = .....(some layers)
keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_epoch=num_batches, callbacks=callbacks, verbose=1)

但是,当我尝试将 validation_data 参数传递给模型时.适合它告诉我我不能将它与发电机一起使用.有没有办法在使用 tf.Dataset 时使用验证

however when I try to pass validation_data parameter to the model. fit it tells me that I cannot use it with the generator. Is there a way to use validation while using tf.Dataset

例如在 tensorflow 中,我可以执行以下操作:

# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)

# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
                                           batch_train.output_shapes)

# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)

然后只需使用 sess.run(init_op_train)sess.run(init_op_valid) 在数据集之间切换

then just use sess.run(init_op_train) and sess.run(init_op_valid) to switch between the datasets

我尝试实现一个回调来做到这一点(切换到验证集,预测并返回),但它告诉我我不能在回调中使用 model.predict

I tried implementing a callback that does just that (switch to validation set, predict and back) but it tells me I can't use model.predict in a callback

有人可以帮助我使用 Keras+Tf.Dataset 进行验证

can someone help me get validation working with Keras+Tf.Dataset

所以最终对我有用的方法是:

# Initialize batch generators(returns tf.Dataset)
batch_train = # returns tf.Dataset
batch_valid = # returns tf.Dataset

# Create TensorFlow Iterator object and wrap it in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_train)

# Create Model
logits = # the keras model
keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit_generator(
    generator=itr_train, validation_data=itr_valid, validation_steps=batch_size,
    epochs=epochs, steps_per_epoch=num_batches, callbacks=cbs, verbose=1, workers=0)

def make_iterator(dataset):
    iterator = dataset.make_one_shot_iterator()
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        while True:
            *inputs, labels = sess.run(next_val)
            yield inputs, labels

这不会引入任何开销

推荐答案

我使用 fit_genertor 解决了这个问题.我在此处找到了解决方案.我应用了@Dat-Nguyen 的解决方案.

I solved the problem by using fit_genertor. I found the solution here. I applied @Dat-Nguyen's solution.

您只需创建两个迭代器,一个用于训练,一个用于验证,然后创建您自己的生成器,您将从数据集中提取批次并以 (batch_data, batch_labels) 的形式提供数据.最后在model.fit_generator 中,您将通过train_generator 和validation_generator.

You need simply to create two iterators, one for training and one for validation and then create your own generator where you will extract batches from the dataset and provide the data in form of (batch_data, batch_labels) . Finally in model.fit_generator you will pass the train_generator and validation_generator.

这篇关于Keras model.fit() 与 tf.dataset API + validation_data的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆