如何在 tf.estimator 的 input_fn 中使用 tf.data 的可初始化迭代器? [英] How to use tf.data's initializable iterators within a tf.estimator's input_fn?
问题描述
我想使用 管理我的训练tf.estimator.Estimator
但在将它与 tf.data
API.
I would like to manage my training with a tf.estimator.Estimator
but have some trouble to use it alongside the tf.data
API.
我有这样的事情:
def model_fn(features, labels, params, mode):
# Defines model's ops.
# Initializes with tf.train.Scaffold.
# Returns an tf.estimator.EstimatorSpec.
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
iterator = dataset.make_initializable_iterator()
return iterator.get_next()
estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
由于我不能在我的用例中使用 make_one_shot_iterator
,我的问题是 input_fn
包含一个应该在 model_fn
中初始化的迭代器> (在这里,我使用 tf.train.Scaffold
初始化本地操作).
As I can't use a make_one_shot_iterator
for my use case, my issue is that input_fn
contains an iterator that should be initialized within model_fn
(here, I use tf.train.Scaffold
to initialize local ops).
另外,我知道我们不能只使用 input_fn = iterator.get_next
否则其他操作不会被添加到同一个图表中.
Also, I understood that we can't only use input_fn = iterator.get_next
otherwise the other ops will not be added to the same graph.
初始化迭代器的推荐方法是什么?
What is the recommended way to initialize the iterator?
推荐答案
从 TensorFlow 1.5 开始,可以让 input_fn
返回一个 tf.data.Dataset
,例如:
As of TensorFlow 1.5, it is possible to make input_fn
return a tf.data.Dataset
, e.g.:
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
return dataset
请参阅 c294fcfd.
对于以前的版本,您可以在 tf.GraphKeys.TABLE_INITIALIZERS
集合中添加迭代器的初始化程序并依赖默认初始化程序.
For previous versions, you can add the iterator's initializer in the tf.GraphKeys.TABLE_INITIALIZERS
collections and rely on the default initializer.
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
这篇关于如何在 tf.estimator 的 input_fn 中使用 tf.data 的可初始化迭代器?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!