使用 tf.data 批量处理来自多个 TFRecord 文件的顺序数据 [英] Batch sequential data coming from multiple TFRecord files with tf.data

查看:44
本文介绍了使用 tf.data 批量处理来自多个 TFRecord 文件的顺序数据的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

让我们考虑将数据集拆分为多个 TFRecord 文件:

Let's consider a dataset split into multiple TFRecord files:

  • 1.tfrecord,
  • 2.tfrecord,

我想生成大小为 t(比如 3)的序列,这些序列由来自同一个 TFRecord 文件的连续元素组成,我不希望序列包含属于到不同的 TFRecord 文件.

I would like to generate sequences of size t (say 3) consisting of consecutive elements from the same TFRecord file, I do not want a sequence to have elements belonging to different TFRecord files.

例如,如果我们有两个包含如下数据的 TFRecord 文件:

For instance, if we have two TFRecord files containing data like:

  • 1.tfrecord:{0, 1, 2, ..., 7}
  • 2.tfrecord:{1000, 1001, 1002, ..., 1007}

没有任何改组,我想得到以下批次:

without any shuffling, I would like to get the following batches:

  • 第一批:0, 1, 2,
  • 第二批:1, 2, 3,
  • ...
  • 第 i 个批次:5, 6, 7,
  • (i+1)-th 批次:1000, 1001, 1002,
  • (i+2)-th 批次:1001, 1002, 1003,
  • ...
  • 第j批次:1005, 1006, 1007,
  • (j+1)-th 批次:0, 1, 2,

我知道如何使用 tf.data.Dataset.windowtf.data.Dataset.batch 生成序列数据,但我不知道如何防止来自不同文件的包含元素的序列.

I know how to generate sequence data using tf.data.Dataset.window or tf.data.Dataset.batch, but I do not know how to prevent a sequence from containing element from different files.

我正在寻找一种可扩展的解决方案,即该解决方案应适用于数百个 TFRecord 文件.

I'm looking for a scalable solutions, i.e. the solution should work with hundred of TFRecord files.

以下是我失败的尝试(完全可重现的示例):

Below is my failed attempt (fully reproducible example):

import tensorflow as tf

# ****************************
# Generate toy TF Record files

def _create_example(i):
    example = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))})
    return tf.train.Example(features=example)

def parse_fn(serialized_example):
    return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data']


num_tf_records = 2
records_per_file = 8
options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
for i in range(num_tf_records):
    with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer:
        for j in range(records_per_file):
            example = _create_example(j + 1000 * i)
            writer.write(example.SerializeToString())
# ****************************
# ****************************


data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\
            .map(lambda x: parse_fn(x))

data = data.window(3, 1, 1, True)\
           .repeat(-1)\
           .flat_map(lambda x: x.batch(3))\
           .batch(16)

data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

输出:

[[   0    1    2]   # good
 [   1    2    3]   # good
 [   2    3    4]   # good
 [   3    4    5]   # good
 [   4    5    6]   # good
 [   5    6    7]   # good
 [   6    7 1000]   # bad – mix of elements from 0.tfrecord and 1.tfrecord
 [   7 1000 1001]   # bad
 [1000 1001 1002]   # good
 [1001 1002 1003]   # good
 [1002 1003 1004]   # good
 [1003 1004 1005]   # good
 [1004 1005 1006]   # good
 [1005 1006 1007]   # good
 [   0    1    2]   # good
 [   1    2    3]]  # good

推荐答案

我认为你只需要 flat_map 那个函数你必须制作 windo 数据集:

I think you just need to flat_map that function you have to make the windo datasets:

def make_dataset_from_filename(filename):
  data = tf.data.TFRecordDataset(filename, compression_type='GZIP')\
           .map(lambda x: parse_fn(x))

  data = data.window(3, 1, 1, True)\
             .repeat(-1)\
             .flat_map(lambda x: x.batch(3))\
             .batch(16)

tf.data.Dataset.list_files('*.tfrecord').flat_map(make_dataset_from_filename)

这篇关于使用 tf.data 批量处理来自多个 TFRecord 文件的顺序数据的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆