Tensorflow输入管道中的在线过采样 [英] Online oversampling in Tensorflow input pipeline

查看:201
本文介绍了Tensorflow输入管道中的在线过采样的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个输入管道,类似于卷积神经网络教程.我的数据集不平衡,我想使用少数族群过采样来尝试解决这个问题.理想情况下,我想在线"执行此操作,即,我不想在磁盘上复制数据样本.

I have an input pipeline similar to the one in the Convolutional Neural Network tutorial. My dataset is imbalanced and I want to use minority oversampling to try to deal with this. Ideally, I want to do this "online", i.e. I don't want to duplicate data samples on disk.

从本质上讲,我想要做的是基于标签重复单个示例(有一定可能性).我已经在Tensorflow中阅读了一些有关Control Flow的内容.看来tf.cond(pred, fn1, fn2)是可行的方式.我只是在努力寻找正确的参数化,因为fn1fn2将需要输出张量列表,列表的大小相同.

Essentially, what I want to do is duplicate individual examples (with some probability) based on the label. I have been reading a bit on Control Flow in Tensorflow. And it seems tf.cond(pred, fn1, fn2) is the way to go. I am just struggling to find the right parameterisation, since fn1 and fn2 would need to output lists of tensors, where the lists have the same size.

这大概是我到目前为止所拥有的:

This is roughly what I have so far:

image = image_preprocessing(image_buffer, bbox, False, thread_id)            
pred = tf.reshape(tf.equal(label, tf.convert_to_tensor([2])), [])
r_image = tf.cond(pred, lambda: [tf.identity(image), tf.identity(image)], lambda: [tf.identity(image),])
r_label = tf.cond(pred, lambda: [tf.identity(label), tf.identity(label)], lambda: [tf.identity(label),])

但是,这引起了我前面提到的错误:

However, this raises an error as I mentioned before:

ValueError: fn1 and fn2 must return the same number of results.

有什么想法吗?

P.S .:这是我的第一个堆栈溢出问题.对我的问题的任何反馈意见都将受到赞赏.

P.S.: this is my first Stack Overflow question. Any feedback on my question is appreciated.

推荐答案

进行了更多研究之后,我找到了想要解决的解决方案.我忘了提及的是,我的问题中提到的代码后面是一个批处理方法,例如batch()batch_join().

After doing a bit more research, I found a solution for what I wanted to do. What I forgot to mention is that the code mentioned in my question is followed by a batch method, such as batch() or batch_join().

这些函数采用一个参数,使您可以对各种批处理大小的张量进行分组,而不仅仅是单个示例的张量.参数为enqueue_many,应设置为True.

These functions take an argument that allows you to group tensors of various batch size rather than just tensors of a single example. The argument is enqueue_many and should be set to True.

以下代码对我有用:

for thread_id in range(num_preprocess_threads):

    # Parse a serialized Example proto to extract the image and metadata.
    image_buffer, label_index = parse_example_proto(
            example_serialized)

    image = image_preprocessing(image_buffer, bbox, False, thread_id)

    # Convert 3D tensor of shape [height, width, channels] to 
    # a 4D tensor of shape [batch_size, height, width, channels]
    image = tf.expand_dims(image, 0)

    # Define the boolean predicate to be true when the class label is 1
    pred = tf.equal(label_index, tf.convert_to_tensor([1]))
    pred = tf.reshape(pred, [])

    oversample_factor = 2
    r_image = tf.cond(pred, lambda: tf.concat(0, [image]*oversample_factor), lambda: image)
    r_label = tf.cond(pred, lambda: tf.concat(0, [label_index]*oversample_factor), lambda: label_index)
    images_and_labels.append([r_image, r_label])

images, label_batch = tf.train.shuffle_batch_join(
    images_and_labels,
    batch_size=batch_size,
    capacity=2 * num_preprocess_threads * batch_size,
    min_after_dequeue=1 * num_preprocess_threads * batch_size,
    enqueue_many=True)

这篇关于Tensorflow输入管道中的在线过采样的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆