如何检查 tf.estimator.inputs.numpy_input_fn 的内容? [英] How do I inspect the contents of tf.estimator.inputs.numpy_input_fn?

查看:30
本文介绍了如何检查 tf.estimator.inputs.numpy_input_fn 的内容?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想在一组数据上反复训练我的张量流图,我认为 tf.estimator.inputs.numpy_input_fn 可能是我正在寻找的.我发现批量大小、重复次数、时期和迭代器之间的区别非常令人困惑,所以我开始尝试检查我的数据集的内容,试图弄清楚实际发生了什么.但是,每当我尝试这样做时,我的程序就会挂起.

I want to train my tensorflow graph over a set of data repeatedly, and I think tf.estimator.inputs.numpy_input_fn might be what I'm looking for. I find the distinction between batch sizes, repeats, epochs and iterators to be incredibly confusing, so I started trying to inspect the contents of my datasets to try to figure out what's actually going on. However, whenever I try to to do this my program just hangs.

这是我想出的最小测试用例来重现这个:

Here is the smallest test case I came up with to reproduce this:

import tensorflow as tf
import numpy

class TestMock(tf.test.TestCase):
    def test(self):
        inputs = numpy.array(range(10))
        targets = numpy.array(range(10,20))

        input_fn = tf.estimator.inputs.numpy_input_fn(
            x=inputs,
            y=targets,
            batch_size=1,
            num_epochs=2,
            shuffle=False)

        print input_fn()
        with self.test_session() as sess:
            # sess.run(input_fn()[0]) # it'll hang if I run this
            pass

if __name__ == '__main__':
    tf.test.main()

这个程序输出

(<tf.Tensor 'fifo_queue_DequeueUpTo:1' shape=(?,) dtype=int64>, <tf.Tensor 'fifo_queue_DequeueUpTo:2' shape=(?,) dtype=int64>)

这看起来很合理,但是一旦我尝试运行 sess.run 行,我的程序就会冻结,我必须终止该进程.我在这里做错了什么?

Which seems reasonable, but as soon as I try to run that sess.run line, my program freezes and I have to kill the process. What am I doing wrong here?

我想要做的是确保我提供给我的流程的数据实际上是我认为的那样,但如果没有检查数据的能力,我认为我无法做到这一点.

What I want to do is make sure that the data I'm feeding into my process is actually what I think it is, but I don't think I can do that without the ability to inspect the data.

推荐答案

从上面的打印语句我们可以推断 input_fn 返回 queue ops,我们需要运行它们使用 start_queue_runnersCoordinator:

From the above print statements we can infer that input_fn returns queue ops, we need to run them using start_queue_runners and Coordinator:

 features_op, labels_op = input_fn()
 with tf.Session() as sess:
     # initialise and start the queues.
     sess.run(tf.local_variables_initializer())

     coordinator = tf.train.Coordinator()
     _ = tf.train.start_queue_runners(coord=coordinator)

    print(sess.run([features_op, labels_op]))

    #[array([0]), array([10])]

这篇关于如何检查 tf.estimator.inputs.numpy_input_fn 的内容?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆