How to use tf.data.Dataset.interleave() in TF 2 with a custom function?

Question

I'm using TF 2.2 and I'm trying to use tf.data to create a pipeline.

The following works fine:

def load_image(filePath, label):

    # tf.print runs for every element; a plain print() only runs once, when the function is traced
    tf.print('Loading File:', filePath)
    raw_bytes = tf.io.read_file(filePath)
    image = tf.io.decode_image(raw_bytes, expand_animations = False)

    return image, label

# TrainDS Pipeline
trainDS = getDataset()
trainDS = trainDS.shuffle(size['train'])
trainDS = trainDS.map(load_image, num_parallel_calls=AUTOTUNE)

for d in trainDS:
    print('Image: {} - Label: {}'.format(d[0], d[1]))

I would like to use load_image() with Dataset.interleave(), so I tried:

# TrainDS Pipeline
trainDS = getDataset()
trainDS = trainDS.shuffle(size['train'])
trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

for d in trainDS:
    print('Image: {} - Label: {}'.format(d[0], d[1]))

But I get the following error:

Exception has occurred: TypeError
`map_func` must return a `Dataset` object. Got <class 'tuple'>
  File "/data/dev/train_daninhas.py", line 44, in <module>
    trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

How can I adapt my code so that Dataset.interleave() works with load_image() and reads the images in parallel?

Answer

As the error suggests, you need to modify load_image so that it returns a Dataset object. Here is an example with two images, tested in TensorFlow 2.2.0:

import tensorflow as tf
filenames = ["./img1.jpg", "./img2.jpg"]
labels = ["A", "B"]

def load_image(filePath, label):
    # tf.print runs for every element; a plain print() only runs once, when the function is traced
    tf.print('Loading File:', filePath)
    raw_bytes = tf.io.read_file(filePath)
    image = tf.io.decode_image(raw_bytes, expand_animations = False)
    return tf.data.Dataset.from_tensors((image, label))

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.interleave(lambda x, y: load_image(x, y), cycle_length=4)

for i in dataset.as_numpy_iterator():
    image = i[0]
    label = i[1]
    print(image.shape)
    print(label.decode())

# (275, 183, 3)
# A
# (275, 183, 3)
# B
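
To see why map_func has to return a Dataset rather than a tuple, it may help to look at a rough pure-Python model of what interleave() does. This is only a sketch of the semantics and not TensorFlow's implementation: nested datasets are modeled as plain lists, everything runs sequentially, and the load_image stand-in below is hypothetical (it returns a placeholder string instead of decoded pixels).

```python
from itertools import islice

_DONE = object()  # sentinel marking the end of the input elements


def interleave(elements, map_func, cycle_length, block_length=1):
    """Simplified, sequential model of Dataset.interleave (not TensorFlow code).

    Each input element (a tuple) is turned into a nested iterable by map_func;
    up to cycle_length nested iterables are open at once and are read
    round-robin, block_length items at a time.
    """
    source = iter(elements)
    # Open the first cycle_length nested iterables.
    slots = [iter(map_func(*e)) for e in islice(source, cycle_length)]
    while slots:
        i = 0
        while i < len(slots):
            block = list(islice(slots[i], block_length))
            yield from block
            if len(block) < block_length:
                # This nested iterable is exhausted: reuse its slot for the
                # next input element, or drop the slot if no input is left.
                nxt = next(source, _DONE)
                if nxt is _DONE:
                    slots.pop(i)
                    continue
                slots[i] = iter(map_func(*nxt))
            i += 1


def load_image(path, label):
    # Hypothetical stand-in: a one-element "dataset", like Dataset.from_tensors(...)
    return [("<pixels of %s>" % path, label)]


pairs = [("a.jpg", "A"), ("b.jpg", "B")]
for image, label in interleave(pairs, load_image, cycle_length=4):
    print(image, label)
```

Because interleave() iterates over whatever map_func returns, each call has to produce a whole (possibly one-element) dataset, which is what wrapping the result in tf.data.Dataset.from_tensors() achieves. Also note that cycle_length only controls how many nested datasets are interleaved; as far as I know, the reads only run in parallel if you also pass num_parallel_calls (for example tf.data.experimental.AUTOTUNE in TF 2.2) to interleave().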

Hope this helps!
