张量流多处理用于图像特征提取 [英] tensorflow multiprocessing for image feature extraction

查看:100
本文介绍了张量流多处理用于图像特征提取的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一些基本功能,可以接收图像的URL并通过VGG-16 CNN对其进行转换:

I have some basic function that takes in the URL of an image and transforms it via a VGG-16 CNN:

def convert_url(_id, url):   
  im = get_image(url)
  return _id, np.squeeze(sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im}))

我有大量的URL(〜60,000个),我想在这些URL上执行此功能.每次迭代花费的时间都超过一秒,这太慢了.我想通过并行使用多个进程来加快速度.无需担心共享状态,因此通常不会出现多线程陷阱.

I have a large set of URLs (~60,000) on which I'd like to perform this function. Each iteration takes longer than a second, which is far too slow. I'd like to speed it up by using multiple processes in parallel. There is no shared state to worry about, so the usual pitfalls of multithreading aren't an issue.

但是,我不确定完全如何使张量流与多处理程序包一起使用.我知道您无法将tensorflow session传递给Pool变量.因此,我尝试初始化session的多个实例:

However, I'm not exactly sure how to actually get tensorflow to work with the multiprocessing package. I know that you can't pass a tensorflow session to a Pool variable. So instead, I tried to initialize multiple instances of session:

def init():
  global sess;
  sess = tf.Session()

但是当我实际启动该过程时,它会无限期地挂起:

But when I actually launch the process, it just hangs indefinitely:

with Pool(processes=3,initializer=init) as pool:
  results = pool.starmap(convert_url, list(id_img_dict.items())[0:5])

请注意,tensorflow图是全局定义的.我认为这是正确的方法,但我不确定:

Note that the tensorflow graph is defined globally. I think that's the right way to do it but I'm not sure:

input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)

arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
  _, end_points = vgg_16(scaled_input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)

有人可以帮助我完成这项工作吗?非常有义务.

Can anyone help me get this working? Much obliged.

推荐答案

忘记python的常规多线程工具,并使用 tensorflow.contrib.data.Dataset .尝试以下类似的方法.

Forget about python's normal multithreading tools and use a tensorflow.contrib.data.Dataset. Try something like the following.

urls = ['img1.jpg', 'img2.jpg', ...]
batch_size = 16
n_batches = len(urls) // batch_size  # do something more elegant for remainder


def load_img(url):
    image = tf.read_file(url, name='image_data')
    image = tf.image.decode_jpeg(image, channels=3, name='image')
    return image


def preprocess(img_tensor):
    img_tensor = (tf.cast(img_tensor, tf.float32) / 255 - 0.5)*2
    img_tensor.set_shape((256, 256, 3))  # whatever shape
    return img_tensor


dataset = tf.contrib.data.Dataset.from_tensor_slices(urls)
dataset = dataset.map(load_img).map(preprocess)

preprocessed_images = dataset.batch(
    batch_size).make_one_shot_iterator().get_next()


arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
    _, end_points = vgg_16(preprocessed_images, is_training=False)
    output = end_points['vgg_16/fc7']


results = []

with tf.Session() as sess:
    tf.train.Saver().restore(sess, checkpoint_file)
    for i in range(n_batches):
        batch_results = sess.run(output)
        results.extend(batch_results)
        print('Done batch %d / %d' % (i+1, n_batches))

这篇关于张量流多处理用于图像特征提取的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆