有没有办法在 Tensorflow 的另一个数据集中使用 tf.data.Dataset? [英] Is there a way to use tf.data.Dataset inside of another Dataset in Tensorflow?

查看:27
本文介绍了有没有办法在 Tensorflow 的另一个数据集中使用 tf.data.Dataset?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在做细分.每个训练样本都有多个带有分割掩码的图像.我正在尝试编写 input_fn 以将每个训练样本的所有掩码图像合并为一张.我计划使用两个 Datasets,一个遍历样本文件夹,另一个将所有掩码作为一个大批量读取,然后将它们合并到一个张量.

I'm doing segmentation. Each training sample have multiple images with segmentation masks. I'm trying to write input_fn to merge all mask images in to one for each training sample. I was planning on using two Datasets, one that iterates over samples folders and another that reads all masks as one large batch and then merges them to one tensor.

调用嵌套的 make_one_shot_iterator 时出现错误.我知道这种方法有点牵强,而且很可能不是为这种用途而设计的数据集.但是我应该如何解决这个问题,以避免使用 tf.py_func?

I'm getting an error when nested make_one_shot_iterator is called. I Know that this approach is a bit of a stretch and mostlikely datasets wheren't designed for such usage. But then how should I approach this problem so that I avoid using tf.py_func?

这是数据集的简化版本:

Here is a simplified version of the dataset:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.
        list_files(sample_path+"/masks/*.png")
        .map(tf.read_file)
        .map(lambda x: tf.image.decode_image(x, channels=1))
        .batch(1024)) # maximum number of objects
    masks = masks_ds.make_one_shot_iterator().get_next()

    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds.map(read_sample)
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...

推荐答案

如果嵌套数据集只有一个元素,可以使用 tf.contrib.data.get_single_element() 在嵌套数据集上而不是创建迭代器:

If the nested dataset has only a single element, you can use tf.contrib.data.get_single_element() on the nested dataset instead of creating an iterator:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
                .map(tf.read_file)
                .map(lambda x: tf.image.decode_image(x, channels=1))
                .batch(1024)) # maximum number of objects
    masks = tf.contrib.data.get_single_element(masks_ds)
    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
sample = ds.make_one_shot_iterator().get_next()

此外,您可以使用tf.data.Dataset.flat_map(), <代码>tf.data.Dataset.interleave()tf.contrib.data.parallel_interleave() 转换w 在函数内执行嵌套的Dataset 计算并将结果展平为单个<代码>数据集.例如,要获取单个 Dataset 中的所有样本:

In addition, you can use the tf.data.Dataset.flat_map(), tf.data.Dataset.interleave(), or tf.contrib.data.parallel_interleave() transformationw to perform a nested Dataset computation inside a function and flatten the result into a single Dataset. For example, to get all of the samples in a single Dataset:

def read_all_samples(sample_path):
    return (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
            .map(tf.read_file)
            .map(lambda x: tf.image.decode_image(x, channels=1))
            .batch(1024)) # maximum number of objects

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.flat_map(read_all_samples)
sample = ds.make_one_shot_iterator().get_next()

这篇关于有没有办法在 Tensorflow 的另一个数据集中使用 tf.data.Dataset?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆