Get data set as numpy array from TFRecordDataset


Question

I'm using the new tf.data API to create an iterator for the CIFAR10 dataset. I read the data from two .tfrecord files: train.tfrecords holds the training data and test.tfrecords holds the test data. This all works fine. At some point, however, I need both data sets (training and test) as numpy arrays.

Is it possible to retrieve a data set as a numpy array from a tf.data.TFRecordDataset object?

Recommended answer

You can use the tf.data.Dataset.batch() transformation and tf.contrib.data.get_single_element() to do this. As a refresher, dataset.batch(n) takes up to n consecutive elements of dataset and converts them into one element by stacking each component along a new leading dimension. This requires every element to have the same fixed shape for each component. If n is larger than the number of elements in dataset (or if n doesn't divide the number of elements exactly), the last batch can be smaller. Therefore, you can choose a value of n that is at least as large as the dataset, so that the whole dataset ends up in a single batch, and do the following:

import numpy as np
import tensorflow as tf

# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...)  # A dataset of tf.string records.
dataset = dataset.map(...)  # Extract components from each tf.string record.

# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)

# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)

# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
    whole_dataset_arrays = sess.run(whole_dataset_tensors)
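
To make the batching behaviour described in the answer concrete, here is a minimal NumPy-only sketch. It is not TensorFlow's implementation: `batch_like` is a hypothetical helper written for illustration, and the five small arrays are made-up stand-ins for decoded records. It only shows why a batch size at least as large as the dataset yields a single element containing everything, and why the last batch can be smaller otherwise.

```python
import numpy as np


def batch_like(elements, n):
    """Group consecutive elements into batches of at most n and stack each one.

    Hypothetical helper that mimics the behaviour described above;
    it is not part of TensorFlow.
    """
    batches = []
    for start in range(0, len(elements), n):
        # np.stack requires every element in the group to have the same shape,
        # which mirrors the fixed-shape requirement of dataset.batch(n).
        batches.append(np.stack(elements[start:start + n]))
    return batches


# Five fake "images" of shape (2, 2), standing in for decoded records.
elements = [np.full((2, 2), i, dtype=np.float32) for i in range(5)]

# n does not divide the element count, so the last batch is smaller.
small = batch_like(elements, 2)
print([b.shape for b in small])    # [(2, 2, 2), (2, 2, 2), (1, 2, 2)]

# n is at least as large as the dataset, so one batch holds everything.
whole = batch_like(elements, 10**9)
print(len(whole), whole[0].shape)  # 1 (5, 2, 2)
```

In the answer's code, the single stacked batch plays the role of `whole[0]` here: after `sess.run`, each component comes back as one array whose leading dimension is the number of records.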
