Best way to process terabytes of data on gcloud ml-engine with keras


Problem Description

I want to train a model on about 2TB of image data on gcloud storage. I saved the image data as separate tfrecords and tried to use the tensorflow data api following this example

https://medium.com/@moritzkrger/speeding-up-keras-with-tfrecord-datasets-5464f9836c36

But it seems like Keras' model.fit(...) doesn't support validation for TFRecord datasets, based on

https://github.com/keras-team/keras/pull/8388

Is there a better approach for processing large amounts of data with keras from ml-engine that I'm missing?

Thanks a lot!

Recommended Answer

If you are willing to use tf.keras instead of actual Keras, you can instantiate a TFRecordDataset with the tf.data API and pass that directly to model.fit(). Bonus: you get to stream directly from Google Cloud Storage, with no need to download the data first:

# Construct a TFRecordDataset
ds_train = tf.data.TFRecordDataset('gs://')  # path to TFRecords on GCS
ds_train = ds_train.shuffle(1000).batch(32)

model.fit(ds_train)
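
Note that the raw records still need to be decoded into tensors before model.fit() can train on them. Below is a minimal parsing sketch that is not part of the original answer: the feature names 'image/encoded' and 'label', the JPEG encoding, and the 224x224 size are all assumptions; adjust the feature spec to match how your TFRecords were written.

def _parse(serialized_example):
    # Hypothetical feature spec -- must mirror how the records were written.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    # Decode the JPEG bytes and resize to a fixed shape for batching.
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    image = tf.image.resize_images(image, [224, 224])
    return image, features['label']

# The map goes before shuffling and batching, e.g.:
# ds_train = tf.data.TFRecordDataset('gs://').map(_parse).shuffle(1000).batch(32)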

To include validation data, create a TFRecordDataset with your validation TFRecords and pass that one to the validation_data argument of model.fit(). Note: this is possible as of TensorFlow 1.9.
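
For example, the validation dataset can be built the same way as the training one (this sketch reuses the hypothetical _parse function from above; in TF 1.x you'll also need to pass validation_steps, since a dataset has no inherent length):

ds_val = tf.data.TFRecordDataset('gs://')  # path to validation TFRecords on GCS
ds_val = ds_val.map(_parse).batch(32)      # same parsing/batching as for training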

Final note: you'll need to specify the steps_per_epoch argument. A hack I use to get the total number of examples across all TFRecord files is to simply iterate over the files and count:

import tensorflow as tf

def n_records(record_list):
    """Get the total number of records in a collection of TFRecords.
    Since a TFRecord file is intended to act as a stream of data,
    this needs to be done naively by iterating over the file and counting.
    See https://stackoverflow.com/questions/40472139

    Args:
        record_list (list): list of GCS paths to TFRecord files.

    Returns:
        int: total number of records across all files.
    """
    counter = 0
    for f in record_list:
        counter += sum(1 for _ in tf.python_io.tf_record_iterator(f))
    return counter

which you can then use to compute steps_per_epoch:

n_train = n_records(['gs://path-to-tfrecords/record1',
                     'gs://path-to-tfrecords/record2'])

steps_per_epoch = n_train // batch_size
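
Tying it together, the full fit call might look like the sketch below; epochs and n_val are placeholders, with n_val computed by the same n_records() trick over the validation files and batch_size assumed to be defined:

model.fit(ds_train,
          steps_per_epoch=steps_per_epoch,
          epochs=10,
          validation_data=ds_val,
          validation_steps=n_val // batch_size)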
