Avoiding tf.data.Dataset.from_tensor_slices with the Estimator API


Question

I'm trying to figure out the recommended way to use the Dataset API together with the Estimator API. Everything I have seen online is some variation of this:

import tensorflow as tf

def train_input_fn():
    # `features` and `labels` are in-memory (e.g. NumPy) arrays
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset


which can then be passed to the estimator's train function:

classifier.train(
    input_fn=train_input_fn,
    # ...
)

The Datasets guide warns:

the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer.
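As a rough illustration of the 2GB limit the guide mentions, the sketch below estimates the raw byte size of one copy of an in-memory array and compares it with the protobuf cap. It is plain Python with no TensorFlow dependency. The helper names are hypothetical, and treating the cap as exactly 2**31 - 1 bytes, ignoring the rest of the graph, and ignoring the extra copies the guide warns about are all simplifying assumptions.

```python
import math

# Assumption: protobuf messages such as tf.GraphDef are capped at 2**31 - 1
# bytes (the largest signed 32-bit integer), i.e. roughly 2 GB.
GRAPHDEF_LIMIT_BYTES = 2**31 - 1


def array_nbytes(shape, itemsize):
    """Raw size in bytes of a dense array with the given shape and element size."""
    return math.prod(shape) * itemsize


def fits_in_graphdef(shape, itemsize=4):
    """Rough check: does ONE copy of the array fit under the GraphDef cap?

    Hypothetical helper. It ignores the rest of the graph and any additional
    copies, so a True result is optimistic. itemsize=4 assumes float32.
    """
    return array_nbytes(shape, itemsize) <= GRAPHDEF_LIMIT_BYTES


# An MNIST-sized float32 training set: 60,000 images of 28 * 28 pixels.
print(array_nbytes((60_000, 784), 4))      # 188160000
print(fits_in_graphdef((60_000, 784)))     # True
# One million such images (3,136,000,000 bytes) no longer fit.
print(fits_in_graphdef((1_000_000, 784)))  # False
```

With a NumPy array, `features.nbytes` gives the same per-copy figure directly.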

The guide then describes a method that defines placeholders and fills them through feed_dict when the session runs the iterator's initializer:

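import tensorflow as tf  # TensorFlow 1.x API

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    # The arrays are fed at run time instead of being embedded in the graph
    sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                              labels_placeholder: labels})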

But with the Estimator API you never run the session yourself. So how do you use the Dataset API with estimators while avoiding the problems associated with from_tensor_slices()?

Answer

To use either initializable or reinitializable iterators, you need a class that inherits from tf.train.SessionRunHook. A hook can access the session at several points during training and evaluation, in particular right after the session has been created (after_create_session).

You can then use this new class to initialize the iterator as you would in a classic, session-based setup. Simply pass the newly created hook through the hooks argument of the training/evaluation functions (estimator.train, estimator.evaluate), or of the appropriate tf.estimator.TrainSpec / tf.estimator.EvalSpec when using tf.estimator.train_and_evaluate.

Here is a quick example that you can adapt to your needs:

import tensorflow as tf  # TensorFlow 1.x API


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Runs an iterator initializer once the session has been created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator, feeding the data through feed_dict
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        # dataset = ...  (apply your own transformations here, e.g. shuffle/repeat/batch)
        # ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()


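        # Bind the initializer now that this graph's placeholders and iterator
        # exist; the hook calls it right after the session is created. X and y
        # are only fed at run time, so they are never embedded in the GraphDef.
        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(
            iterator.initializer, feed_dict={X_pl: X, y_pl: y})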

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])  # Don't forget to pass the hook!
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
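Why assign iterator_initializer_func inside input_fn rather than in get_inputs? The placeholders and the iterator only exist once the estimator has called input_fn inside its own graph, and the hook runs later, right after the session is created. The sketch below is a TensorFlow-free toy model of that call order. ToySession and toy_train are hypothetical stand-ins (strings play the role of ops), and toy_train is a simplified picture of what Estimator.train does internally, not its actual implementation.

```python
class IteratorInitializerHook:
    """Toy version of the answer's hook (no tf.train.SessionRunHook base)."""

    def __init__(self):
        self.iterator_initializer_func = None  # set later, inside input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session)


class ToySession:
    """Hypothetical stand-in: records every run() call instead of executing ops."""

    def __init__(self):
        self.calls = []

    def run(self, fetch, feed_dict=None):
        self.calls.append((fetch, feed_dict))


def toy_train(input_fn, hooks, steps=2):
    """Simplified call order of Estimator.train (an assumption, not TF's code)."""
    next_example = input_fn()            # 1. build the graph: input_fn runs here
    session = ToySession()               # 2. create the session
    for hook in hooks:                   # 3. hooks see the new session
        hook.after_create_session(session, coord=None)
    for _ in range(steps):               # 4. training steps
        session.run(next_example)
    return session


def get_inputs(X):
    hook = IteratorInitializerHook()

    def input_fn():
        # Strings stand in for tf.placeholder and iterator.initializer.
        X_pl, initializer = "X_pl", "iterator.initializer"
        hook.iterator_initializer_func = lambda sess: sess.run(
            initializer, feed_dict={X_pl: X})
        return "next_example"

    return input_fn, hook


input_fn, hook = get_inputs([1.0, 2.0, 3.0])
print(hook.iterator_initializer_func)  # None: input_fn has not run yet
session = toy_train(input_fn, [hook])
print(session.calls[0])   # ('iterator.initializer', {'X_pl': [1.0, 2.0, 3.0]})
print(session.calls[1:])  # [('next_example', None), ('next_example', None)]
```

The point that carries over to the real code: the hook only holds a callback, and that callback is rebound every time the estimator calls input_fn, so it always refers to the placeholders and iterator of the graph that is about to run.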
