TF Dataset API for image augmentation

Question

I am using the tf Dataset API to read images and their labels. I would like to apply multiple image augmentations to the images to increase the size of my training data. What I have done so far is shown below.

def flip(self, img, lbl):
  image = tf.image.flip_left_right(img)
  return image, lbl

def transpose(self, img, lbl):
  image = tf.image.transpose_image(img)
  return image, lbl

# just read and resize the image.
process_fn = lambda img, lbl: self.read_convert_image(img, lbl, self.args)
flip_fn = lambda img, lbl: self.flip(img, lbl)
transpose_fn = lambda img, lbl: self.transpose(img, lbl)

train_set = self.train_set.map(process_fn)

# one extra map and concatenate per augmentation
flipped_data = train_set.map(flip_fn)
transpose_data = train_set.map(transpose_fn)

train_set = train_set.concatenate(flipped_data)
train_set = train_set.concatenate(transpose_data)

# shuffle and repeat only after concatenating: repeat() without a count
# makes the dataset infinite, so anything concatenated after it is never reached
train_set = train_set.shuffle(args.batch_size)
train_set = train_set.repeat()
train_set = train_set.batch(args.batch_size)
iterator = train_set.make_one_shot_iterator()  # TF 1.x API

images, labels = iterator.get_next()

Is there a better way to do multiple augmentations? The problem with the above approach is that every augmentation function I add requires another map and another concatenate.

Thanks
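One way to keep the question's concatenate approach while avoiding a hand-written `map`/`concatenate` pair for every augmentation is to keep the augmentation functions in a list and fold them into the dataset with a loop. This is a minimal sketch, assuming TensorFlow 2.x with eager execution and square images (so a transpose keeps the shape compatible); `expand_with_augmentations`, `flip`, `transpose` and the dummy data are hypothetical names for illustration. `shuffle()` and `repeat()` are applied only after concatenating, because concatenating onto an already repeated (infinite) dataset never reaches the appended elements.

```python
import tensorflow as tf


def expand_with_augmentations(dataset, augment_fns):
    """Return `dataset` followed by one augmented copy per function in `augment_fns`.

    Each function takes (image, label) and returns (image, label).
    """
    expanded = dataset
    for fn in augment_fns:
        # One map + concatenate per augmentation, generated by the loop
        expanded = expanded.concatenate(dataset.map(fn))
    return expanded


# Hypothetical augmentations with the same (img, lbl) signature as in the question.
def flip(img, lbl):
    return tf.image.flip_left_right(img), lbl


def transpose(img, lbl):
    # tf.image.transpose is the TF 2.x name of tf.image.transpose_image;
    # it swaps height and width, so square images keep shapes compatible
    return tf.image.transpose(img), lbl


# Usage with dummy data: four square 8x8 RGB images.
images = tf.random.uniform([4, 8, 8, 3])
labels = tf.range(4)
base = tf.data.Dataset.from_tensor_slices((images, labels))

train_set = expand_with_augmentations(base, [flip, transpose])  # 12 elements
# Shuffle and repeat after concatenating, never before.
train_set = train_set.shuffle(buffer_size=12).repeat().batch(2)
batch_images, batch_labels = next(iter(train_set))
```

Adding another augmentation then only means appending a function to the list. Each augmented copy is still materialized as extra elements, so an epoch grows linearly with the number of functions; the answer below instead applies transformations randomly on the fly.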

Answer

If you want to do the augmentations yourself, without relying on Keras's ImageDataGenerator, you can create a function like img_aug and then use it in your model or in the Dataset API pipeline. The code below is only a sketch, but it shows the idea. You define all your transformations; then, in each round, you draw a random number per transformation and apply the first transformation whose number is above a common threshold (or leave the image unchanged if none is). The rounds are repeated up to X times (4 in the code below), so the same transformation can be applied more than once.

import tensorflow as tf

def img_aug(image):
  threshold = tf.constant(0.9, dtype=tf.float32)

  def body(i, img):
    # The branches are defined inside body so that they transform the
    # current loop value, not the image originally passed in.
    def h_flip():
      return tf.image.flip_left_right(img)
    def v_flip():
      return tf.image.flip_up_down(img)
    def identity():
      return img

    p_order = tf.random.uniform(shape=[2], minval=0., maxval=1., dtype=tf.float32)
    # With exclusive=False, tf.case runs the first branch whose predicate is
    # True, or `default` if none is.
    out = tf.case([
        (tf.greater(p_order[0], threshold), h_flip),
        (tf.greater(p_order[1], threshold), v_flip),
    ], default=identity, exclusive=False)
    return (i + 1, out)

  def cond(i, *args):
    return i < 4  # max number of transformations

  parallel_iterations = 1
  # tf.while_loop returns the final loop variables; keep the image.
  _, distorted_image = tf.while_loop(cond, body, [tf.constant(0), image],
                                     parallel_iterations=parallel_iterations)
  return distorted_image
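If applying each transformation at most once per image is enough, a simpler pattern than `tf.while_loop` is to draw one random number per transformation and use `tf.cond` inside a single `Dataset.map`. This is a minimal sketch of that alternative, not the answer's method: it assumes TensorFlow 2.4+ (for `tf.data.AUTOTUNE`) and shape-preserving transformations, and `random_apply`, `AUGMENTATIONS`, `augment`, the probability of 0.5 and the brightness delta are hypothetical choices for illustration.

```python
import tensorflow as tf


def random_apply(fn, image, prob):
    """Return fn(image) with probability `prob`, otherwise `image` unchanged."""
    return tf.cond(tf.random.uniform([]) < prob,
                   lambda: fn(image),
                   lambda: image)


# Hypothetical list of shape-preserving augmentations; add entries as needed.
AUGMENTATIONS = [
    tf.image.flip_left_right,
    tf.image.flip_up_down,
    lambda img: tf.image.adjust_brightness(img, delta=0.1),
]


def augment(image, label, prob=0.5):
    # Each augmentation is decided independently, so any subset may be applied.
    for fn in AUGMENTATIONS:
        image = random_apply(fn, image, prob)
    return image, label


# Usage with dummy data: augmentation happens on the fly inside a single map.
images = tf.random.uniform([4, 8, 8, 3])
labels = tf.range(4)
train_set = (tf.data.Dataset.from_tensor_slices((images, labels))
             .shuffle(buffer_size=4)
             .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
             .repeat()
             .batch(2))
batch_images, batch_labels = next(iter(train_set))
```

Because augmentation happens inside `map`, the dataset size stays the same and each epoch sees a different random variant of every image; adding an augmentation means adding one entry to the list rather than another `map` and `concatenate`.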
