How to create tf.data.dataset from directories of tfrecords?


Problem description


My dataset has several directories, and each directory corresponds to one class. Each directory contains a different number of .tfrecords files. My question is: how can I sample 5 images (each .tfrecord file corresponds to one image) from each directory? My other question is: how can I sample 5 of these directories and then sample 5 images from each?


I just want to do it with tf.data.dataset. So I want to have a dataset from which I can get an iterator, where iterator.next() gives me a batch of 25 images: 5 samples from each of 5 classes.

Answer


If the number of classes is greater than 5, then you can use the new tf.contrib.data.sample_from_datasets() API (currently available in tf-nightly and will be available in TensorFlow 1.9).

import tensorflow as tf

directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", ...]

CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5
BATCH_SIZE = CLASSES_PER_BATCH * EXAMPLES_PER_CLASS_PER_BATCH
NUM_CLASSES = len(directories)


# Build one dataset per class.
per_class_datasets = [
    tf.data.TFRecordDataset(tf.data.Dataset.list_files(d)) for d in directories]

# Next, build a dataset where each element is a vector of 5 classes to be chosen
# for a particular batch.
classes_per_batch_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH])

# Transform the dataset of per-batch class vectors into a dataset with one
# one-hot element per example (i.e. 25 examples per batch).
class_dataset = classes_per_batch_dataset.flat_map(
    lambda classes: tf.data.Dataset.from_tensor_slices(
        tf.one_hot(classes, NUM_CLASSES)).repeat(EXAMPLES_PER_CLASS_PER_BATCH))

# Use `tf.contrib.data.sample_from_datasets()` to select an example from the
# appropriate dataset in `per_class_datasets`.
example_dataset = tf.contrib.data.sample_from_datasets(per_class_datasets,
                                                       class_dataset)

# Finally, combine 25 consecutive examples into a batch.
result = example_dataset.batch(BATCH_SIZE)
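
To consume this pipeline the way the question describes, you can build an iterator in the usual TF 1.x style. Note that the elements are still serialized tf.Example protos at this point; decoding them into image tensors depends on how the records were written, so this minimal sketch only shows the iteration:

iterator = result.make_one_shot_iterator()
next_batch = iterator.get_next()  # one batch of 25 serialized records

with tf.Session() as sess:
    batch_of_records = sess.run(next_batch)
    print(batch_of_records.shape)  # (25,)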

If you have exactly 5 classes, you can define a nested dataset for each directory and combine them using Dataset.interleave():

# NOTE: We're assuming that the 0th directory contains elements from class 0, etc.
directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", "class_4/*"]
directories = tf.data.Dataset.from_tensor_slices(directories)
directories = directories.apply(tf.contrib.data.enumerate_dataset())    

# Define a function that maps each (class, directory) pair to the (shuffled)
# records in those files.
def per_directory_dataset(class_label, directory_glob):
  files = tf.data.Dataset.list_files(directory_glob, shuffle=True)
  records = tf.data.TFRecordDataset(files)
  # Zip the records with their class. 
  # NOTE: This part might not be necessary if the records contain information about
  # their class that can be parsed from them.
  return tf.data.Dataset.zip(
      (records, tf.data.Dataset.from_tensors(class_label).repeat(None)))

# NOTE: The `cycle_length` and `block_length` here aren't strictly necessary,
# because the batch size is exactly `number of classes * images per class`.
# However, these arguments may be useful if you want to decouple these numbers.
merged_records = directories.interleave(per_directory_dataset,
                                        cycle_length=5, block_length=5)
merged_records = merged_records.batch(25)
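
In both approaches the elements are serialized tf.Example protos rather than decoded images. Below is a minimal parsing sketch, assuming the records were written with an "image_raw" bytes feature holding a JPEG-encoded image; the feature key and encoding are assumptions about your schema, not part of the original answer:

def parse_record(serialized_record, class_label):
  # "image_raw" is an assumed feature key; match it to how your records were written.
  features = tf.parse_single_example(
      serialized_record,
      features={"image_raw": tf.FixedLenFeature([], tf.string)})
  image = tf.image.decode_jpeg(features["image_raw"], channels=3)
  return image, class_label

# Apply the map before the final `.batch(25)` call above, so parsing sees one
# (record, label) pair at a time:
#   merged_records = merged_records.map(parse_record).batch(25)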

