TensorFlow数据增强 [英] Tensorflow data augmentation
本文介绍了TensorFlow数据增强的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我要转换此Keras数据增强工作流:
datagen = ImageDataGenerator(
rescale=1./255,
rotation_range = 10,
horizontal_flip = True,
width_shift_range=0.1,
height_shift_range=0.1,
fill_mode = 'nearest')
以下是代码片段,但这两个函数都不起作用,因为它不支持批次维度!
import numpy as np
def augment(x, y):
x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
x = tf.keras.preprocessing.image.random_rotation(
x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
interpolation_order=1)
return x, y
X = np.random.random(size=(256, 48, 48, 1))
y = np.random.randint(0, 7, size=(256,))
dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.map(augment)
dataset = dataset.batch(16, drop_remainder=False)
dataset = dataset.prefetch(buffer_size=1)
推荐答案
我在运行代码时收到以下错误:AttributeError: 'Tensor' object has no attribute 'ndim'
。似乎不可能使用tf.data.Dataset
运行augment
函数,因为它不能处理张量。一种解决方法是将增强函数包装在tf.py_function:
import tensorflow as tf
import numpy as np
def augment(x, y):
x = x.numpy()
x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
x = tf.keras.preprocessing.image.random_rotation(
x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
interpolation_order=1)
return x, y
X = np.random.random(size=(256, 48, 48, 1))
y = np.random.randint(0, 7, size=(256,))
dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.map(
lambda x, y: tf.py_function(
func=augment,
inp=[x, y],
Tout=[tf.float32, tf.int64]))
dataset = dataset.batch(16, drop_remainder=False)
dataset = dataset.prefetch(buffer_size=1)
上面的代码应该运行时没有任何错误。如果您经常需要用tf.py_function
包装您的函数,那么编写一个修饰符可能会很方便(也很干净)。大概是这样的:
import tensorflow as tf
import numpy as np
def map_decorator(func):
def wrapper(*args):
return tf.py_function(
func=func,
inp=[*args],
Tout=[a.dtype for a in args])
return wrapper
@map_decorator
def augment(x, y):
x = x.numpy()
x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
x = tf.keras.preprocessing.image.random_rotation(
x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
interpolation_order=1)
return x, y
X = np.random.random(size=(256, 48, 48, 1))
y = np.random.randint(0, 7, size=(256,))
dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.map(augment)
dataset = dataset.batch(16, drop_remainder=False)
dataset = dataset.prefetch(buffer_size=1)
希望能有所帮助!
这篇关于TensorFlow数据增强的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文