How do I split Tensorflow datasets?

Question

I have a tensorflow dataset based on one .tfrecord file. How do I split the dataset into test and train datasets, for example 70% train and 30% test?

My Tensorflow version is 1.8. I've checked, and there is no "split_v" function as mentioned in the possible duplicate. Also, I am working with a tfrecord file.

Recommended answer

You can use Dataset.take() and Dataset.skip():

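import tensorflow as tf

# DATASET_SIZE (the total number of records) and FLAGS.input_file are assumed
# to be defined elsewhere in your program.
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer_size. A fixed seed (42 is arbitrary) together with
# reshuffle_each_iteration=False keeps the order stable, so the splits stay disjoint.
full_dataset = full_dataset.shuffle(
    buffer_size=DATASET_SIZE, seed=42, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)   # first 70%
test_dataset = full_dataset.skip(train_size)    # remaining 30%
val_dataset = test_dataset.skip(test_size)      # whatever is left after the test part
test_dataset = test_dataset.take(test_size)     # first 15% of the remainder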

For more generality, I gave an example using a 70/15/15 train/val/test split but if you don't need a test or a val set, just ignore the last 2 lines.
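One detail of the snippet above is easy to miss: val_size is computed but never used, and val_dataset simply receives whatever remains after the train and test parts are removed. Because int() truncates, that remainder can be slightly larger than 15%. The sketch below works through the arithmetic in plain Python; split_sizes is a hypothetical helper written for this illustration, not a TensorFlow function.

```python
def split_sizes(dataset_size, train_frac=0.7, test_frac=0.15):
    """Return (train, val, test) sizes as the take/skip chain would produce them.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    train_size = int(train_frac * dataset_size)  # int() truncates toward zero
    test_size = int(test_frac * dataset_size)
    # val_dataset = test_dataset.skip(test_size) keeps everything that is left,
    # so the validation split absorbs any rounding remainder.
    val_size = dataset_size - train_size - test_size
    return train_size, val_size, test_size


sizes_large = split_sizes(1000)
sizes_small = split_sizes(10)
print(sizes_large)
print(sizes_small)
```

If exact proportions matter, one option is to call take(val_size) on the validation split as well and accept that a few records go unused.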

Take:

Creates a Dataset with at most count elements from this dataset.

Skip:

Creates a Dataset that skips count elements from this dataset.
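To make the two descriptions concrete, here is a minimal plain-Python sketch that mimics the documented behavior with itertools.islice. The take and skip functions below are stand-ins written for this illustration (they are not the TensorFlow methods), and the 20-element list is a hypothetical dataset.

```python
from itertools import islice


def take(iterable, count):
    """Stand-in for Dataset.take: at most `count` elements from the start."""
    return list(islice(iterable, count))


def skip(iterable, count):
    """Stand-in for Dataset.skip: everything after the first `count` elements."""
    return list(islice(iterable, count, None))


records = list(range(20))  # hypothetical dataset of 20 records
train_size = 14            # 70% of 20
test_size = 3              # 15% of 20, truncated

train = take(records, train_size)
rest = skip(records, train_size)   # the 6 records not used for training
val = skip(rest, test_size)
test = take(rest, test_size)

print(train, val, test)
```

The three lists are disjoint and together cover every record, which is the property the take/skip chain relies on. It only holds if the underlying order is the same each time the data is read.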

You may also want to look into Dataset.shard():

Creates a Dataset that includes only 1/num_shards of this dataset.
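Dataset.shard(num_shards, index) keeps the elements whose position modulo num_shards equals index, so it can also serve as a deterministic split that needs no dataset size. The sketch below imitates that selection rule with list slicing; shard here is a plain-Python stand-in, and the choice of 10 shards with 7 for training is an assumption made for the example.

```python
def shard(elements, num_shards, index):
    """Stand-in for Dataset.shard: every num_shards-th element, starting at index."""
    return elements[index::num_shards]


records = list(range(20))  # hypothetical dataset of 20 records
num_shards = 10

# Assign shards 0-6 to training and shards 7-9 to testing (roughly 70/30).
train = sorted(x for i in range(0, 7) for x in shard(records, num_shards, i))
test = sorted(x for i in range(7, 10) for x in shard(records, num_shards, i))

print(shard(records, 3, 0))
print(train)
print(test)
```

With a real tf.data pipeline the per-shard datasets would have to be combined (for example with concatenate) instead of sorted into a list, so treat this purely as a picture of which elements each shard selects.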
