使用 TensorFlow Dataset API 和 flat_map 的并行线程 [英] Parallel threads with TensorFlow Dataset API and flat_map
问题描述
我正在将我的 TensorFlow 代码从旧的队列界面更改为新的 DatasetAPI.使用旧接口,我可以为 tf.train.shuffle_batch
队列指定 num_threads
参数.但是,控制 Dataset API 中线程数量的唯一方法似乎是在 map
函数中使用 num_parallel_calls
参数.但是,我使用的是 flat_map
函数,它没有这样的参数.
I'm changing my TensorFlow code from the old queue interface to the new Dataset API. With the old interface I could specify the num_threads
argument to the tf.train.shuffle_batch
queue. However, the only way to control the amount of threads in the Dataset API seems to be in the map
function using the num_parallel_calls
argument. However, I'm using the flat_map
function instead, which doesn't have such an argument.
问题:有没有办法控制flat_map
函数的线程/进程数?或者有没有办法将 map
与 flat_map
结合使用,并且仍然指定并行调用的数量?
Question: Is there a way to control the number of threads/processes for the flat_map
function? Or is there are way to use map
in combination with flat_map
and still specify the number of parallel calls?
请注意,并行运行多个线程至关重要,因为我打算在数据进入队列之前在 CPU 上运行大量预处理.
Note that it is of crucial importance to run multiple threads in parallel, as I intend to run heavy pre-processing on the CPU before data enters the queue.
有两个(这里和这里) GitHub 上的相关帖子,但我认为他们没有回答这个问题.
There are two (here and here) related posts on GitHub, but I don't think they answer this question.
这是我的用例的最小代码示例,用于说明:
Here is a minimal code example of my use-case for illustration:
with tf.Graph().as_default():
data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
input_tensors = (data,)
def pre_processing_func(data_):
# normally I would do data-augmentation here
results = (tf.expand_dims(data_, axis=0),)
return tf.data.Dataset.from_tensor_slices(results)
dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
dataset = dataset_source.flat_map(pre_processing_func)
# do something with 'dataset'
推荐答案
据我所知,目前 flat_map
不提供并行选项.鉴于大部分计算是在 pre_processing_func
中完成的,您可以用作解决方法的是并行 map
调用,然后进行一些缓冲,然后使用 flat_map
使用恒等 lambda 函数调用,负责展平输出.
To the best of my knowledge, at the moment flat_map
does not offer parallelism options.
Given that the bulk of the computation is done in pre_processing_func
, what you might use as a workaround is a parallel map
call followed by some buffering, and then using a flat_map
call with an identity lambda function that takes care of flattening the output.
在代码中:
NUM_THREADS = 5
BUFFER_SIZE = 1000
def pre_processing_func(data_):
# data-augmentation here
# generate new samples starting from the sample `data_`
artificial_samples = generate_from_sample(data_)
return atificial_samples
dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
map(pre_processing_func, num_parallel_calls=NUM_THREADS).
prefetch(BUFFER_SIZE).
flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)).
shuffle(BUFFER_SIZE)) # my addition, probably necessary though
注意(对我自己和任何试图了解管道的人):
由于 pre_processing_func
从初始样本开始生成任意数量的新样本(以 (?, 512)
形状的矩阵组织),flat_map
调用对于将所有生成的矩阵转换为包含单个样本的 Dataset
是必要的(因此在 lambda 中使用 tf.data.Dataset.from_tensor_slices(x)
)和然后将所有这些数据集展平为一个包含单个样本的大Dataset
.
Note (to myself and whoever will try to understand the pipeline):
Since pre_processing_func
generates an arbitrary number of new samples starting from the initial sample (organised in matrices of shape (?, 512)
), the flat_map
call is necessary to turn all the generated matrices into Dataset
s containing single samples (hence the tf.data.Dataset.from_tensor_slices(x)
in the lambda) and then flatten all these datasets into one big Dataset
containing individual samples.
.shuffle()
将数据集或生成的样本打包在一起可能是个好主意.
It's probably a good idea to .shuffle()
that dataset, or generated samples will be packed together.
这篇关于使用 TensorFlow Dataset API 和 flat_map 的并行线程的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!