How to iterate a dataset several times using TensorFlow's Dataset API?

Question

How can I output the values in a dataset several times? (The dataset is created with TensorFlow's Dataset API.)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10

for i in range(epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

Error message:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

How can I make this work?

Recommended answer

First of all, I advise you to read the Dataset guide, which describes the Dataset API in detail.

Your question is about iterating over the data several times. Here are two solutions for that:

  1. Iterate through all epochs at once, with no information about where each individual epoch ends

import tensorflow as tf

epoch   = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()

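num_batch = 0  # total number of elements consumed
j = 0          # expected value within the current epoch
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99:  # last element of the epoch consumed; the next one starts at 0
            j = 0
    except tf.errors.OutOfRangeError:
        # Raised once all `epoch` repetitions have been consumed
        break

print("Num Batch:", num_batch)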
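
Because the repeated stream carries no epoch marker, the `if j > 99` check above only works when the dataset size is known in advance. Under that same assumption, the epoch index and the position inside the epoch can also be derived from the running element count. The sketch below is plain Python with no TensorFlow calls; the helper name `epoch_and_position` is hypothetical and is not part of any TensorFlow API.

```python
def epoch_and_position(step, dataset_size):
    """Map a running element count to (epoch index, position in that epoch)."""
    return divmod(step, dataset_size)


dataset_size = 100  # assumed known; matches Dataset.range(100) above
epochs = 10

boundaries = []
for step in range(dataset_size * epochs):  # stands in for the repeated stream
    epoch, position = epoch_and_position(step, dataset_size)
    if position == dataset_size - 1:
        # The last element of this epoch has just been consumed
        boundaries.append(epoch)
```

If the dataset size is not known ahead of time, this bookkeeping is not possible, which is the case the second option addresses.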

  2. The second option tells you when each epoch ends, so you can, for example, check the validation loss at that point:

import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()

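num_batch = 0  # total number of elements consumed

for e in range(epoch):
    print("Epoch:", e)
    j = 0
    sess.run(iterator.initializer)  # reset the iterator to the start of the dataset
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            # End of this epoch; a validation pass could go here
            break

print("Num Batch:", num_batch)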
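
The two solutions above target the TensorFlow 1.x session API. As a rough sketch only, and assuming TensorFlow 2.x with eager execution enabled (which the original answer does not cover), the per-epoch loop can be written without a session or an explicit initializer, because a dataset can be iterated directly like any Python iterable:

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled

epochs = 10
dataset = tf.data.Dataset.range(100)

num_elements = 0
for e in range(epochs):
    # Each `for` loop over the dataset creates a fresh iterator, so there is
    # no initializer to run and no OutOfRangeError to catch.
    for j, value in enumerate(dataset):
        assert j == value.numpy()
        num_elements += 1
    # The inner loop finishing marks the end of epoch `e`.

print("Num elements:", num_elements)
```

Here the end of the inner loop plays the role of the `OutOfRangeError` handler in the second option.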
