How do I split TensorFlow datasets?
Problem description
I have a TensorFlow dataset based on one .tfrecord file. How do I split it into training and test datasets, e.g. 70% train and 30% test?
My TensorFlow version is 1.8. I've checked, and there is no "split_v" function as mentioned in the possible duplicate. Also, I am working with a tfrecord file.
Recommended answer
You may use Dataset.take() and Dataset.skip():
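import tensorflow as tf

# DATASET_SIZE (total number of records) and FLAGS.input_file must be defined beforehand.
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer_size. A fixed seed (42 is an arbitrary choice) and
# reshuffle_each_iteration=False keep the order stable, so the splits stay disjoint.
full_dataset = full_dataset.shuffle(
    buffer_size=DATASET_SIZE, seed=42, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)  # first 70%
test_dataset = full_dataset.skip(train_size)   # remaining 30%
val_dataset = test_dataset.skip(test_size)     # second half of the remainder
test_dataset = test_dataset.take(test_size)    # first half of the remainder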
To keep the example general, it uses a 70/15/15 train/val/test split. If you only need two sets (for example 70% train and 30% test), just ignore the last 2 lines.
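One detail worth watching is that int() truncates, so the three sizes may not add up to the dataset size and a few records can be silently dropped. Below is a minimal sketch in plain Python of one way to avoid that. The helper name split_sizes and its parameters are hypothetical and not part of TensorFlow.

```python
def split_sizes(dataset_size, train_frac=0.7, val_frac=0.15):
    """Return (train, val, test) sizes that add up to dataset_size.

    Hypothetical helper for illustration; not a TensorFlow function.
    """
    train_size = int(train_frac * dataset_size)
    val_size = int(val_frac * dataset_size)
    # Give the remainder to the test set so truncation drops no record.
    test_size = dataset_size - train_size - val_size
    return train_size, val_size, test_size


train_size, val_size, test_size = split_sizes(10)
print(train_size, val_size, test_size)
```

The resulting sizes can then be passed to take() and skip() in the same way as in the answer.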
Dataset.take():
Creates a Dataset with at most count elements from this dataset.
Dataset.skip():
Creates a Dataset that skips count elements from this dataset.
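To see why combining the two gives non-overlapping splits, here is a small pure-Python model of the semantics described above. It is only an illustration on a list of integers, not TensorFlow code. The functions take and skip below are stand-ins written with itertools.islice, and the 20 "records" are assumed to be already shuffled.

```python
from itertools import islice


def take(items, count):
    # Model of Dataset.take(count): at most the first `count` elements.
    return list(islice(items, count))


def skip(items, count):
    # Model of Dataset.skip(count): everything after the first `count` elements.
    return list(islice(items, count, None))


records = list(range(20))   # stand-in for 20 already-shuffled records
train = take(records, 14)   # 70%
rest = skip(records, 14)    # remaining 30%
test = take(rest, 3)        # first half of the remainder
val = skip(rest, 3)         # second half of the remainder
print(train, test, val)
```

Because each split is a contiguous slice of the same ordering, the three parts are disjoint only if that ordering is identical every time the dataset is read.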
You may also want to look into Dataset.shard():
Creates a Dataset that includes only 1/num_shards of this dataset.
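As a rough illustration of what that means, Dataset.shard(num_shards, index) keeps every num_shards-th element, starting at position index. The pure-Python model below mimics this with list slicing. The function shard here is a stand-in for illustration, not the TensorFlow implementation.

```python
def shard(items, num_shards, index):
    # Model of Dataset.shard(num_shards, index): keep the elements whose
    # position p satisfies p % num_shards == index.
    return list(items)[index::num_shards]


records = list(range(10))
shards = [shard(records, 3, i) for i in range(3)]
print(shards)
```

Unlike take() and skip(), this interleaves the parts across the whole dataset, and all shards have roughly the same size, so on its own it does not give an uneven ratio such as 70/30.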