Keras model.fit()与tf.dataset API + validation_data [英] Keras model.fit() with tf.dataset API + validation_data

查看:784
本文介绍了Keras model.fit()与tf.dataset API + validation_data的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

所以我通过以下代码使我的keras模型可以与tf.Dataset一起使用:

So I have got my keras model to work with a tf.Dataset through the following code:

# Initialize batch generators(returns tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)

# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()

# Create Model
logits = .....(some layers)
keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_epoch=num_batches, callbacks=callbacks, verbose=1)

但是,当我尝试将validation_data参数传递给模型时.适合它告诉我不能与发电机一起使用.有没有一种方法可以在使用tf.Dataset时使用验证

however when I try to pass validation_data parameter to the model. fit it tells me that I cannot use it with the generator. Is there a way to use validation while using tf.Dataset

例如在tensorflow中,我可以执行以下操作:

# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)

# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
                                           batch_train.output_shapes)

# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)

然后只需使用sess.run(init_op_train)sess.run(init_op_valid)在数据集之间切换

then just use sess.run(init_op_train) and sess.run(init_op_valid) to switch between the datasets

我尝试实现一个仅执行此操作的回调(切换到验证集,进行预测并返回),但它告诉我不能在回调中使用model.predict

I tried implementing a callback that does just that (switch to validation set, predict and back) but it tells me I can't use model.predict in a callback

有人可以帮助我使用Keras + Tf.Dataset进行验证

can someone help me get validation working with Keras+Tf.Dataset

最后,对我来说最有用的是,归功于所选的答案:

# Initialize batch generators(returns tf.Dataset)
batch_train = # returns tf.Dataset
batch_valid = # returns tf.Dataset

# Create TensorFlow Iterator object and wrap it in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_train)

# Create Model
logits = # the keras model
keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit_generator(
    generator=itr_train, validation_data=itr_valid, validation_steps=batch_size,
    epochs=epochs, steps_per_epoch=num_batches, callbacks=cbs, verbose=1, workers=0)

def make_iterator(dataset):
    iterator = dataset.make_one_shot_iterator()
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        while True:
            *inputs, labels = sess.run(next_val)
            yield inputs, labels

这不会带来任何间接费用

推荐答案

我通过使用fit_genertor解决了这个问题.我在此处找到了解决方案.我应用了@ Dat-Nguyen的解决方案.

I solved the problem by using fit_genertor. I found the solution here. I applied @Dat-Nguyen's solution.

您只需创建两个迭代器,一个用于训练,一个用于验证,然后创建您自己的生成器,即可从数据集中提取批次,并以(batch_data,batch_labels)的形式提供数据.最后,在model.fit_generator中,您将传递train_generator和validation_generator.

You need simply to create two iterators, one for training and one for validation and then create your own generator where you will extract batches from the dataset and provide the data in form of (batch_data, batch_labels) . Finally in model.fit_generator you will pass the train_generator and validation_generator.

这篇关于Keras model.fit()与tf.dataset API + validation_data的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆