Tensorflow Dataset API 在完成一个 epoch 后恢复 Iterator [英] Tensorflow Dataset API restore Iterator after completing one epoch

查看:34
本文介绍了Tensorflow Dataset API 在完成一个 epoch 后恢复 Iterator的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有 190 个特征和标签,我的批量大小是 20,但经过 9 次迭代 tf.reshape 返回异常 要重塑的输入是一个具有 21 个值的张量,但请求的形状有60,我知道这是由于 Iterator.get_next() 造成的.我如何恢复我的迭代器,以便它再次从头开始提供批次服务?

I have 190 features and labels,My batch size is 20 but after 9 iterations tf.reshape is returning exception Input to reshape is a tensor with 21 values,but the requested shape has 60 and i know it is due to Iterator.get_next().How do i restore my Iterator so that it will again start serving batches from the beginning?

推荐答案

如果你想重启一个 tf.data.Iterator 从它的 Dataset 开始,考虑使用 initializable 迭代器,它有您可以运行以重新初始化迭代器的操作:

If you want to restart a tf.data.Iterator from the beginning of its Dataset, consider using an initializable iterator, which has an operation you can run to re-initialize the iterator:

dataset = ...  # A `tf.data.Dataset` instance.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

train_op = ...  # Something that depends on `next_element`.

for _ in range(NUM_EPOCHS):
  # Initialize the iterator at the beginning of `dataset`.
  sess.run(iterator.initializer)

  # Loop over the examples in `iterator`, running `train_op`.
  try:
    while True:
      sess.run(train_op)

  except tf.errors.OutOfRangeError:  # Thrown at the end of the epoch.
    pass

  # Perform any per-epoch computations here.

有关不同类型的 Iterator 的更多详细信息,请参阅 tf.data 程序员指南.

For more details on the different kinds of Iterator, see the tf.data programmer's guide.

这篇关于Tensorflow Dataset API 在完成一个 epoch 后恢复 Iterator的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆