Get data set as NumPy array from TFRecordDataset
Question
I'm using the new tf.data API to create an iterator for the CIFAR10 dataset. I'm reading the data from two .tfrecord files: one holds the training data (train.tfrecords) and the other holds the test data (test.tfrecords). This all works fine. At some point, however, I need both data sets (training data and test data) as NumPy arrays.
Is it possible to retrieve a data set as a NumPy array from a tf.data.TFRecordDataset object?
Answer
You can use the tf.data.Dataset.batch() transformation and tf.contrib.data.get_single_element() to do this.
As a refresher, dataset.batch(n) takes up to n consecutive elements of dataset and combines them into one element by stacking each component along a new leading (batch) dimension. This requires all elements to have the same shape for each component. If n is larger than the number of elements in dataset (or if n doesn't divide the number of elements exactly), the last batch is smaller than n. Therefore, if you choose a value of n at least as large as the dataset, the batched dataset contains exactly one element holding every record, and you can do the following:
```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x: tf.contrib and tf.Session were removed in 2.x.

# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...)  # A dataset of tf.string records.
dataset = dataset.map(...)  # Extract components from each tf.string record.

# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)

# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)

# Create a session and evaluate `whole_dataset_tensors` to get NumPy arrays.
with tf.Session() as sess:
    whole_dataset_arrays = sess.run(whole_dataset_tensors)
```
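To see why a batch size at least as large as the dataset yields exactly one element, here is a pure-NumPy sketch that mimics the batching rule described above without TensorFlow. The helper batch_elements and the fake CIFAR10-shaped data are hypothetical illustrations, not part of the tf.data API; the sketch assumes the default behaviour in which a smaller final batch is kept.

```python
import numpy as np


def batch_elements(elements, n):
    """Hypothetical helper mimicking Dataset.batch(n) on a list of (image, label) tuples."""
    batches = []
    for start in range(0, len(elements), n):
        group = elements[start:start + n]  # up to n consecutive elements
        # Stack each component separately; every element must share the same shape per component.
        batches.append(tuple(np.stack(component) for component in zip(*group)))
    return batches


# Five fake CIFAR10-like elements: a 32x32x3 uint8 image and an int64 label.
elements = [(np.full((32, 32, 3), i, dtype=np.uint8), np.int64(i)) for i in range(5)]

small = batch_elements(elements, 2)
print([b[0].shape for b in small])  # [(2, 32, 32, 3), (2, 32, 32, 3), (1, 32, 32, 3)]

whole = batch_elements(elements, np.iinfo(np.int64).max)
print(len(whole), whole[0][0].shape, whole[0][1])  # 1 (5, 32, 32, 3) [0 1 2 3 4]
```

With n = 2 the last batch holds the single leftover element; with the huge n there is only one batch, which is exactly the element that get_single_element() extracts.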
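In TensorFlow 2.x, tf.contrib and tf.Session no longer exist, but the same batch-then-extract idea works with eager execution. The sketch below assumes TensorFlow 2.6 or newer for the Dataset.get_single_element() method (earlier 2.x releases offer tf.data.experimental.get_single_element instead). Because the real CIFAR10 files are not available here, it first writes a tiny hypothetical TFRecord file; the feature names "image" and "label" and the raw-bytes image encoding are assumptions, so adapt parse to however your own train.tfrecords and test.tfrecords were written.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.6+ with eager execution

# Hypothetical feature spec: raw uint8 image bytes plus an int64 label per record.
FEATURES = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}


def write_fake_records(path, num_examples):
    """Write a small TFRecord file so the sketch runs without the real CIFAR10 files."""
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(num_examples):
            image = np.full((32, 32, 3), i, dtype=np.uint8)
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
            }))
            writer.write(example.SerializeToString())


def parse(record):
    """Decode one serialized tf.train.Example into (image, label) tensors."""
    parsed = tf.io.parse_single_example(record, FEATURES)
    image = tf.reshape(tf.io.decode_raw(parsed["image"], tf.uint8), [32, 32, 3])
    return image, parsed["label"]


def tfrecord_to_numpy(path):
    """Load a whole TFRecord file as a tuple of NumPy arrays (must fit in memory)."""
    dataset = tf.data.TFRecordDataset(path).map(parse)
    # One batch containing every element, then extract that single element eagerly.
    whole = dataset.batch(np.iinfo(np.int64).max).get_single_element()
    return tuple(t.numpy() for t in whole)


path = os.path.join(tempfile.mkdtemp(), "train.tfrecords")
write_fake_records(path, 5)
images, labels = tfrecord_to_numpy(path)
print(images.shape, labels)  # (5, 32, 32, 3) [0 1 2 3 4]
```

Calling tfrecord_to_numpy once for the training file and once for the test file would give both data sets as NumPy arrays, provided each fits in memory.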