tf.data.Iterator.get_next(): 如何在 tf.while_loop 中前进? [英] tf.data.Iterator.get_next(): How to advance in tf.while_loop?

查看:47
本文介绍了tf.data.Iterator.get_next(): 如何在 tf.while_loop 中前进?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

目前我尝试在 Tensorflow while 循环中实现所有训练,但我在使用 Tensorflow 数据集 API 的迭代器时遇到了问题.

Currently I try to implement all training in a Tensorflow while loop, but I've got problems with the Tensorflow dataset API's Iterator.

通常,当调用 sess.run() 时,Iterator.get_next() 前进到下一个元素.但是,我需要在一次运行中前进到下一个元素.我该怎么做?

Usually, when calling sess.run(), Iterator.get_next() advances to the next element. However, I need to advance to the next element INSIDE one run. How do I do this?

下面的小例子说明了我的问题:

The following small example shows my problem:

import tensorflow as tf
import numpy as np


def for_loop(condition, modifier, body_op, idx=0):
    idx = tf.convert_to_tensor(idx)

    def body(i):
        with tf.control_dependencies([body_op(i)]):
            return [modifier(i)]

    # do the loop:
    loop = tf.while_loop(condition, body, [idx])
    return loop


x = np.arange(10)

data = tf.data.Dataset.from_tensor_slices(x)
data = data.repeat()

iterator = data.make_initializable_iterator()
smpl = iterator.get_next()

loop = for_loop(
    condition=lambda i: tf.less(i, 5),
    modifier=lambda i: tf.add(i, 1),
    body_op=lambda i: tf.Print(smpl, [smpl], message="This is sample: ")
)

sess = tf.InteractiveSession()
sess.run(iterator.initializer)
sess.run(loop)

输出:

This is sample: [0]
This is sample: [0]
This is sample: [0]
This is sample: [0]
This is sample: [0]

我总是得到完全相同的元素.

I always get exactly the same element.

推荐答案

每次想要迭代一次运行"时都需要调用 iterator.get_next().

You need to call iterator.get_next() every time you want to "iterate inside one run".

例如在您的玩具示例中,只需将您的 body_op 替换为:

For instance in your toy example, just replace your body_op with:

 body_op=lambda i: tf.Print(i, [iterator.get_next()], message="This is sample: ")
# This is sample: [0]
# This is sample: [1]
# This is sample: [2]
# This is sample: [3]
# This is sample: [4]

这篇关于tf.data.Iterator.get_next(): 如何在 tf.while_loop 中前进?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆