How to use a tf.data.Dataset design for both training and inference?
Problem Description
Say we have an input x and a label y:
iterator = tf.data.Iterator.from_structure((x_type, y_type), (x_shape, y_shape))
tf_x, tf_y = iterator.get_next()
Now I use a generator function to create the dataset:
def gen():
    for ....:
        yield (x, y)
ds = tf.data.Dataset.from_generator(gen, (x_type, y_type), (x_shape, y_shape))
In my graph, I use tf_x and tf_y for training, and that works fine. But now I want to run inference, where I don't have the label y. One workaround I came up with is to fake a y (e.g. tf.zeros(y_shape)) and then use placeholders to initialize the iterator:
x_placeholder = tf.placeholder(...)
y_placeholder = tf.placeholder(...)
ds = tf.data.Dataset.from_tensors((x_placeholder, y_placeholder))
ds_init_op = iterator.make_initializer(ds)
# fake(y) stands for dummy labels of the right shape (e.g. zeros)
sess.run(ds_init_op, feed_dict={x_placeholder: x, y_placeholder: fake(y)})
My question is: is there a cleaner way to do this, without faking a y at inference time?
Update:
I experimented a little, and it looks like one dataset operation, unzip, is missing:
import numpy as np
import tensorflow as tf
x_type = tf.float32
y_type = tf.float32
x_shape = tf.TensorShape([None, 128])
y_shape = tf.TensorShape([None, 10])
x_shape_nobatch = tf.TensorShape([128])
y_shape_nobatch = tf.TensorShape([10])
iterator_x = tf.data.Iterator.from_structure((x_type,), (x_shape,))
iterator_y = tf.data.Iterator.from_structure((y_type,), (y_shape,))
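def gen1():
    for i in range(100):
        # from_generator expects each element to match the (x_type,) structure,
        # so yield a 1-tuple rather than a bare array.
        yield (np.random.randn(128),)
ds1 = tf.data.Dataset.from_generator(gen1, (x_type,), (x_shape_nobatch,))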
ds1 = ds1.batch(5)
ds1_init_op = iterator_x.make_initializer(ds1)
def gen2():
    for i in range(80):
        yield np.random.randn(128), np.random.randn(10)
ds2 = tf.data.Dataset.from_generator(gen2, (x_type, y_type), (x_shape_nobatch, y_shape_nobatch))
ds2 = ds2.batch(10)
# my ds2 has two tensors in one element, now the problem is
# how can I unzip this dataset so that I can apply them to iterator_x and iterator_y?
# such as:
ds2_x, ds2_y = tf.data.Dataset.unzip(ds2) #?? missing this unzip operation!
ds2_x_init_op = iterator_x.make_initializer(ds2_x)
ds2_y_init_op = iterator_y.make_initializer(ds2_y)
tf_x = iterator_x.get_next()
tf_y = iterator_y.get_next()
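The missing unzip can be emulated by projecting each component out of the paired dataset with `Dataset.map`. The sketch below is a minimal illustration, assuming TensorFlow 2.4+ in eager mode (needed for `output_signature`); the deterministic generator is a hypothetical stand-in for the question's random `gen2`, chosen so the pairing can be checked. Each mapped dataset returns a 1-tuple so that its structure matches the question's `(x_type,)` and `(y_type,)` iterators.

```python
import numpy as np
import tensorflow as tf

def gen2():
    # Hypothetical deterministic data: element i is filled with the value i,
    # so we can verify that x and y stay aligned after "unzipping".
    for i in range(80):
        yield np.full(128, i, dtype=np.float32), np.full(10, i, dtype=np.float32)

ds2 = tf.data.Dataset.from_generator(
    gen2,
    output_signature=(
        tf.TensorSpec(shape=(128,), dtype=tf.float32),
        tf.TensorSpec(shape=(10,), dtype=tf.float32),
    ),
).batch(10)

# Emulated unzip: map each (x, y) element to a 1-tuple holding one component.
ds2_x = ds2.map(lambda x, y: (x,))
ds2_y = ds2.map(lambda x, y: (y,))

(x_batch,) = next(iter(ds2_x))
(y_batch,) = next(iter(ds2_y))
print(x_batch.shape, y_batch.shape)
```

Each mapped dataset runs the generator on its own, so the x and y streams stay paired only when the source is deterministic. With the question's `np.random.randn` generator, the two streams would not correspond. In the TF 1.x graph code, the same `ds2.map(lambda x, y: (x,))` result is what you would pass to `iterator_x.make_initializer`.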
Recommended Answer
The purpose of the datasets API is to avoid feeding values directly to the session (because that makes the data flow first to the client and then to a device).
All the examples I've seen that use the datasets API also use the Estimator API, where you can provide different input functions for training and inference.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data  # TF 1.x MNIST helper

def train_dataset(data_dir):
    """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
    data = input_data.read_data_sets(data_dir, one_hot=True).train
    return tf.data.Dataset.from_tensor_slices((data.images, data.labels))

def infer_dataset(data_dir):
    """Returns a tf.data.Dataset yielding images for inference."""
    data = input_data.read_data_sets(data_dir, one_hot=True).test
    return tf.data.Dataset.from_tensors((data.images,))
...
def train_input_fn():
    dataset = train_dataset(FLAGS.data_dir)
    dataset = dataset.shuffle(buffer_size=50000).batch(1024).repeat(10)
    (images, labels) = dataset.make_one_shot_iterator().get_next()
    return (images, labels)

mnist_classifier.train(input_fn=train_input_fn)
...
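def infer_input_fn():
    # Features only: no labels are needed for prediction.
    return infer_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()

# predict() returns a generator; iterate over it to obtain the predictions.
mnist_classifier.predict(input_fn=infer_input_fn)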
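For readers on TensorFlow 2.x, the same idea (one pipeline with labels for training, a separate label-free pipeline for inference) can be sketched with Keras instead of the Estimator API used above. This is a swapped-in API, not the answerer's code, and the arrays, layer size, and batch sizes are hypothetical placeholders.

```python
import numpy as np
import tensorflow as tf

# Hypothetical in-memory data standing in for the question's generators.
x_train = np.random.randn(80, 128).astype(np.float32)
y_train = np.random.randn(80, 10).astype(np.float32)
x_infer = np.random.randn(100, 128).astype(np.float32)

# Training pipeline: each element is a (features, label) pair.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(80).batch(10)

# Inference pipeline: each element is features only, so no fake labels are needed.
infer_ds = tf.data.Dataset.from_tensor_slices(x_infer).batch(5)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(128,)),
    tf.keras.layers.Dense(10),
])
model.compile(optimizer="adam", loss="mse")

model.fit(train_ds, epochs=1, verbose=0)
preds = model.predict(infer_ds, verbose=0)  # NumPy array with one row per input
print(preds.shape)
```

If only a labeled dataset is at hand, `train_ds.map(lambda x, y: x)` strips the labels before it is passed to `predict`.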