setting batch_size when using input_fn for tf.contrib.learn.Estimator

Question

I am using the high-level Estimator on TF:

estim = tf.contrib.learn.Estimator(...)
estim.fit(some_input)

If some_input contains x, y, and batch_size, the code runs, but with a warning; so I tried to use input_fn, and managed to send x and y through this input_fn, but not the batch_size. I didn't find any example for it.

Could anyone share a simple example that uses input_fn as input to estim.fit / estim.evaluate, and uses batch_size as well?

Do I have to use tf.train.batch? If so, how does it merge into the higher-level implementation (tf.layers), given that I don't have access to the graph's tf.Graph() or session?

This is the warning that I get:

WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py:657: calling evaluate
(from tensorflow.contrib.learn.python.learn.estimators.estimator) with y is deprecated and will be removed after 2016-12-01.
Instructions for updating: Estimator is decoupled from Scikit Learn interface by moving into separate class SKCompat. Arguments x, y and batch_size are only available in the SKCompat class, Estimator will only accept input_fn.
Example conversion:
est = Estimator(...) -> est = SKCompat(Estimator(...))
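
Side note: if you just want to silence the warning and keep the old x / y / batch_size interface, the conversion the warning suggests is to wrap the Estimator in SKCompat. A minimal sketch of that, assuming my_model_fn, train_x, and train_y stand in for your own model function and numpy arrays:

import tensorflow as tf

# Wrap the Estimator in SKCompat to keep the Scikit-Learn-style API;
# `my_model_fn`, `train_x`, and `train_y` are hypothetical stand-ins.
est = tf.contrib.learn.SKCompat(
    tf.contrib.learn.Estimator(model_fn=my_model_fn))
est.fit(x=train_x, y=train_y, batch_size=128, steps=1000)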

Answer

The link provided in Roi's own comment was indeed really helpful. Since I was struggling with the same question for a while as well, I would like to summarize the answer provided by that link as a reference:

import tensorflow as tf

def batched_input_fn(dataset_x, dataset_y, batch_size):
    def _input_fn():
        # Hold the full dataset in constant tensors.
        all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32)
        all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32)
        # Emit one (x, y) example at a time from the in-memory tensors.
        sliced_input = tf.train.slice_input_producer([all_x, all_y])
        # Group single examples into batches of `batch_size`.
        return tf.train.batch(sliced_input, batch_size=batch_size)
    return _input_fn

This can then be used like in the following example (using TensorFlow v1.1):

model = CustomModel(FLAGS.learning_rate)
estimator = tf.estimator.Estimator(model_fn=model.build(), params=model.params())

estimator.train(
    input_fn=batched_input_fn(train.features, train.labels, FLAGS.batch_size),
    steps=FLAGS.train_steps)
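
Evaluation should work the same way with the same helper. A hedged sketch, assuming test.features, test.labels, and FLAGS.eval_steps exist as analogues of the training objects above:

# Reuse the input_fn factory for evaluation; `test` and FLAGS.eval_steps
# are assumed analogues of the training objects, not part of the original answer.
estimator.evaluate(
    input_fn=batched_input_fn(test.features, test.labels, FLAGS.batch_size),
    steps=FLAGS.eval_steps)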

Unfortunately, this approach is about 10x slower compared to manual feeding (using TensorFlow's low-level API) or compared to using the whole dataset with train.shape[0] == batch_size and not using tf.train.slice_input_producer() and tf.train.batch() at all. At least on my machine (CPU only). I'm really wondering why this approach is so slow. Any ideas?

I could speed it up a bit by using num_threads > 1 as a parameter for tf.train.batch(). On a VM with 2 CPUs, I was able to double the performance of this batching mechanism compared to the default num_threads=1, but it is still 5x slower than manual feeding. Results might be different on a native system, or on a system that uses all CPU cores for the input pipeline and the GPU for the model computation. It would be great if somebody could post their results in the comments.
