如何根据张量流中的某些谓词从队列中过滤张量? [英] How to filter tensor from queue based on some predicate in tensorflow?

查看:38
本文介绍了如何根据张量流中的某些谓词从队列中过滤张量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如何使用谓词函数过滤存储在队列中的数据?例如,假设我们有一个存储特征和标签张量的队列,我们​​只需要满足谓词的那些.我尝试了以下实现但没有成功:

How can I filter data stored in a queue using a predicate function? For example, let's say we have a queue that stores tensors of features and labels and we just need those that meet the predicate. I tried the following implementation without success:

feature, label = queue.dequeue()
if (predicate(feature, label)):
    enqueue_op = another_queue.enqueue(feature, label)

推荐答案

最直接的方法是出列一批,通过谓词测试运行它们,使用 tf.where 生成与谓词匹配的密集向量,并使用 tf.gather 收集结果,并将该批次入队.如果你想让它自动发生,你可以在第二个队列上启动一个队列运行器 - 最简单的方法是使用 tf.train.batch:

The most straightforward way to do this is to dequeue a batch, run them through the predicate test, use tf.where to produce a dense vector of the ones that match the predicate, and use tf.gather to collect the results, and enqueue that batch. If you want that to happen automatically, you can start a queue runner on the second queue - the easiest way to do that is to use tf.train.batch:

示例:

import numpy as np
import tensorflow as tf

a = tf.constant(np.array([5, 1, 9, 4, 7, 0], dtype=np.int32))

q = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
enqueue = q.enqueue_many([a])
dequeue = q.dequeue_many(6)
predmatch = tf.less(dequeue, [5])
selected_items = tf.reshape(tf.where(predmatch), [-1])
found = tf.gather(dequeue, selected_items)

secondqueue = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
enqueue2 = secondqueue.enqueue_many([found])
dequeue2 = secondqueue.dequeue_many(3) # XXX, hardcoded

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(enqueue)  # Fill the first queue
  sess.run(enqueue2) # Filter, push into queue 2
  print sess.run(dequeue2) # Pop items off of queue2

谓词产生一个布尔向量;tf.where 生成真实值索引的密集向量,tf.gather 根据这些索引从原始张量中收集项目.

The predicate produces a boolean vector; the tf.where produces a dense vector of the indexes of the true values, and the tf.gather collects items from your original tensor based upon those indexes.

在这个例子中,很多东西都是硬编码的,你需要在现实中进行非硬编码,当然,但希望它显示了你正在尝试做的事情的结构(创建过滤管道).在实践中,您希望 QueueRunners 在那里保持自动搅拌.使用 tf.train.batch 对自动处理非常有用——请参阅 线程和队列 了解更多详情.

A lot of things are hardcoded in this example that you'd need to make not-hardcoded in reality, of course, but hopefully it shows the structure of what you're trying to do (create a filtering pipeline). In practice, you'd want QueueRunners on there to keep things churning automatically. Using tf.train.batch is very useful to handle that automatically -- see Threading and Queues for more detail.

这篇关于如何根据张量流中的某些谓词从队列中过滤张量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆